diff --git a/Stwo_wrapper/Cargo.toml b/Stwo_wrapper/Cargo.toml deleted file mode 100644 index 0f314a4..0000000 --- a/Stwo_wrapper/Cargo.toml +++ /dev/null @@ -1,22 +0,0 @@ -[workspace] -members = ["crates/prover"] -resolver = "2" - -[workspace.package] -version = "0.1.1" -edition = "2021" - -[workspace.dependencies] -blake2 = "0.10.6" -blake3 = "1.5.0" -educe = "0.5.0" -hex = "0.4.3" -itertools = "0.12.0" -num-traits = "0.2.17" -thiserror = "1.0.56" -bytemuck = "1.14.3" -tracing = "0.1.40" - -[profile.bench] -codegen-units = 1 -lto = true diff --git a/Stwo_wrapper/LICENSE b/Stwo_wrapper/LICENSE deleted file mode 100644 index 2e0cecd..0000000 --- a/Stwo_wrapper/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright 2024 StarkWare Industries Ltd. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file diff --git a/Stwo_wrapper/README.md b/Stwo_wrapper/README.md deleted file mode 100644 index ece38cc..0000000 --- a/Stwo_wrapper/README.md +++ /dev/null @@ -1,62 +0,0 @@ -
- -![STWO](resources/img/logo.png) - -GitHub Workflow Status (with event) - - - -Project license -StarkWare -
- -
-

- - Paper - - | - - Documentation - - | - - Benchmarks - -

-
- -# Stwo - -## 🌟 About - -Stwo is a next generation implementation of a [CSTARK](https://eprint.iacr.org/2024/278) prover and verifier, written in Rust 🦀. - -> **Stwo is a work in progress.** -> -> It is not recommended to use it in a production setting yet. - -## 🚀 Key Features - -- **Circle STARKs:** Based on the latest cryptographic research and innovations in the ZK field. -- **High performance:** Stwo is designed to be extremely fast and efficient. -- **Flexible:** Adaptable for various validity proof applications. - -## 📊 Benchmarks - -Run `poseidon_benchmark.sh` to run a single-threaded poseidon2 hash proof benchmark. - -Further benchmarks can be run using `cargo bench`. - -Visual representation of benchmarks can be found [here](https://starkware-libs.github.io/stwo/dev/bench/index.html). - -## 📜 License - -This project is licensed under the **Apache 2.0 license**. - -See [LICENSE](LICENSE) for more information. - - - - - diff --git a/Stwo_wrapper/WORKSPACE b/Stwo_wrapper/WORKSPACE deleted file mode 100644 index e69de29..0000000 diff --git a/Stwo_wrapper/crates/prover/Cargo.toml b/Stwo_wrapper/crates/prover/Cargo.toml deleted file mode 100644 index 587e655..0000000 --- a/Stwo_wrapper/crates/prover/Cargo.toml +++ /dev/null @@ -1,110 +0,0 @@ -[package] -name = "stwo-prover" -version.workspace = true -edition.workspace = true - -[features] -parallel = ["rayon"] -slow-tests = [] - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -blake2.workspace = true -blake3.workspace = true -bytemuck = { workspace = true, features = ["derive", "extern_crate_alloc"] } -cfg-if = "1.0.0" -downcast-rs = "1.2" -educe.workspace = true -hex.workspace = true -itertools.workspace = true -num-traits.workspace = true -rand = { version = "0.8.5", default-features = false, features = ["small_rng"] } -starknet-crypto = "0.6.2" -starknet-ff = "0.3.7" -ark-bls12-381 = "0.4.0" -ark-ff = "0.4.0" -thiserror.workspace = true -tracing.workspace = true -rayon = { version = "1.10.0", optional = true } -serde = { version = "1.0", features = ["derive"] } -crypto-bigint = "0.5.5" -ark-serialize = "0.4.0" -serde_json = "1.0.116" - -[dev-dependencies] -aligned = "0.4.2" -test-log = { version = "0.2.15", features = ["trace"] } -tracing-subscriber = "0.3.18" -[target.'cfg(all(target_family = "wasm", not(target_os = "wasi")))'.dev-dependencies] -wasm-bindgen-test = "0.3.43" - -[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies.criterion] -features = ["html_reports"] -version = "0.5.1" - -# Default features cause compile error: -# "Rayon cannot be used when targeting wasi32. Try disabling default features." -[target.'cfg(target_arch = "wasm32")'.dev-dependencies.criterion] -default-features = false -features = ["html_reports"] -version = "0.5.1" - -[lib] -bench = false -crate-type = ["cdylib", "lib"] - -[lints.rust] -warnings = "deny" -future-incompatible = "deny" -nonstandard-style = "deny" -rust-2018-idioms = "deny" -unused = "deny" - -[[bench]] -harness = false -name = "bit_rev" - -[[bench]] -harness = false -name = "eval_at_point" - -[[bench]] -harness = false -name = "fft" - -[[bench]] -harness = false -name = "field" - -[[bench]] -harness = false -name = "fri" - -[[bench]] -harness = false -name = "lookups" - -[[bench]] -harness = false -name = "matrix" - -[[bench]] -harness = false -name = "merkle" - -[[bench]] -harness = false -name = "poseidon" - -[[bench]] -harness = false -name = "prefix_sum" - -[[bench]] -harness = false -name = "quotients" - -[[bench]] -harness = false -name = "pcs" diff --git a/Stwo_wrapper/crates/prover/benches/README.md b/Stwo_wrapper/crates/prover/benches/README.md deleted file mode 100644 index 8e6d73f..0000000 --- a/Stwo_wrapper/crates/prover/benches/README.md +++ /dev/null @@ -1,2 +0,0 @@ -dev benchmark results can be seen at -https://starkware-libs.github.io/stwo/dev/bench/index.html diff --git a/Stwo_wrapper/crates/prover/benches/bit_rev.rs b/Stwo_wrapper/crates/prover/benches/bit_rev.rs deleted file mode 100644 index 6e287e6..0000000 --- a/Stwo_wrapper/crates/prover/benches/bit_rev.rs +++ /dev/null @@ -1,39 +0,0 @@ -#![feature(iter_array_chunks)] - -use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; -use itertools::Itertools; -use stwo_prover::core::fields::m31::BaseField; - -pub fn cpu_bit_rev(c: &mut Criterion) { - use stwo_prover::core::utils::bit_reverse; - // TODO(andrew): Consider using same size for all. - const SIZE: usize = 1 << 24; - let data = (0..SIZE).map(BaseField::from).collect_vec(); - c.bench_function("cpu bit_rev 24bit", |b| { - b.iter_batched( - || data.clone(), - |mut data| bit_reverse(&mut data), - BatchSize::LargeInput, - ); - }); -} - -pub fn simd_bit_rev(c: &mut Criterion) { - use stwo_prover::core::backend::simd::bit_reverse::bit_reverse_m31; - use stwo_prover::core::backend::simd::column::BaseColumn; - const SIZE: usize = 1 << 26; - let data = (0..SIZE).map(BaseField::from).collect::(); - c.bench_function("simd bit_rev 26bit", |b| { - b.iter_batched( - || data.data.clone(), - |mut data| bit_reverse_m31(&mut data), - BatchSize::LargeInput, - ); - }); -} - -criterion_group!( - name = bit_rev; - config = Criterion::default().sample_size(10); - targets = simd_bit_rev, cpu_bit_rev); -criterion_main!(bit_rev); diff --git a/Stwo_wrapper/crates/prover/benches/eval_at_point.rs b/Stwo_wrapper/crates/prover/benches/eval_at_point.rs deleted file mode 100644 index 64d1eec..0000000 --- a/Stwo_wrapper/crates/prover/benches/eval_at_point.rs +++ /dev/null @@ -1,35 +0,0 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use rand::rngs::SmallRng; -use rand::{Rng, SeedableRng}; -use stwo_prover::core::backend::cpu::CpuBackend; -use stwo_prover::core::backend::simd::SimdBackend; -use stwo_prover::core::circle::CirclePoint; -use stwo_prover::core::fields::m31::BaseField; -use stwo_prover::core::poly::circle::{CirclePoly, PolyOps}; - -const LOG_SIZE: u32 = 20; - -fn bench_eval_at_secure_point(c: &mut Criterion, id: &str) { - let poly = CirclePoly::new((0..1 << LOG_SIZE).map(BaseField::from).collect()); - let mut rng = SmallRng::seed_from_u64(0); - let x = rng.gen(); - let y = rng.gen(); - let point = CirclePoint { x, y }; - c.bench_function( - &format!("{id} eval_at_secure_field_point 2^{LOG_SIZE}"), - |b| { - b.iter(|| B::eval_at_point(black_box(&poly), black_box(point))); - }, - ); -} - -fn eval_at_secure_point_benches(c: &mut Criterion) { - bench_eval_at_secure_point::(c, "simd"); - bench_eval_at_secure_point::(c, "cpu"); -} - -criterion_group!( - name = benches; - config = Criterion::default().sample_size(10); - targets = eval_at_secure_point_benches); -criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/fft.rs b/Stwo_wrapper/crates/prover/benches/fft.rs deleted file mode 100644 index 35841d7..0000000 --- a/Stwo_wrapper/crates/prover/benches/fft.rs +++ /dev/null @@ -1,131 +0,0 @@ -#![feature(iter_array_chunks)] - -use std::hint::black_box; -use std::mem::{size_of_val, transmute}; - -use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput}; -use itertools::Itertools; -use stwo_prover::core::backend::simd::column::BaseColumn; -use stwo_prover::core::backend::simd::fft::ifft::{ - get_itwiddle_dbls, ifft, ifft3_loop, ifft_vecwise_loop, -}; -use stwo_prover::core::backend::simd::fft::rfft::{fft, get_twiddle_dbls}; -use stwo_prover::core::backend::simd::fft::transpose_vecs; -use stwo_prover::core::backend::simd::m31::PackedBaseField; -use stwo_prover::core::fields::m31::BaseField; -use stwo_prover::core::poly::circle::CanonicCoset; - -pub fn simd_ifft(c: &mut Criterion) { - let mut group = c.benchmark_group("iffts"); - - for log_size in 16..=28 { - let domain = CanonicCoset::new(log_size).circle_domain(); - let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); - let twiddle_dbls_refs = twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(); - let values: BaseColumn = (0..domain.size()).map(BaseField::from).collect(); - group.throughput(Throughput::Bytes(size_of_val(&*values.data) as u64)); - group.bench_function(BenchmarkId::new("simd ifft", log_size), |b| { - b.iter_batched( - || values.clone().data, - |mut data| unsafe { - ifft( - transmute(data.as_mut_ptr()), - black_box(&twiddle_dbls_refs), - black_box(log_size as usize), - ); - }, - BatchSize::LargeInput, - ) - }); - } -} - -pub fn simd_ifft_parts(c: &mut Criterion) { - const LOG_SIZE: u32 = 14; - - let domain = CanonicCoset::new(LOG_SIZE).circle_domain(); - let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); - let twiddle_dbls_refs = twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(); - let values: BaseColumn = (0..domain.size()).map(BaseField::from).collect(); - - let mut group = c.benchmark_group("ifft parts"); - - // Note: These benchmarks run only on 2^LOG_SIZE elements because of their parameters. - // Increasing the figure above won't change the runtime of these benchmarks. - group.throughput(Throughput::Bytes(4 << LOG_SIZE)); - group.bench_function(format!("simd ifft_vecwise_loop 2^{LOG_SIZE}"), |b| { - b.iter_batched( - || values.clone().data, - |mut values| unsafe { - ifft_vecwise_loop( - transmute(values.as_mut_ptr()), - black_box(&twiddle_dbls_refs), - black_box(9), - black_box(0), - ) - }, - BatchSize::LargeInput, - ); - }); - group.bench_function(format!("simd ifft3_loop 2^{LOG_SIZE}"), |b| { - b.iter_batched( - || values.clone().data, - |mut values| unsafe { - ifft3_loop( - transmute(values.as_mut_ptr()), - black_box(&twiddle_dbls_refs[3..]), - black_box(7), - black_box(4), - black_box(0), - ) - }, - BatchSize::LargeInput, - ); - }); - - const TRANSPOSE_LOG_SIZE: u32 = 20; - let transpose_values: BaseColumn = (0..1 << TRANSPOSE_LOG_SIZE).map(BaseField::from).collect(); - group.throughput(Throughput::Bytes(4 << TRANSPOSE_LOG_SIZE)); - group.bench_function(format!("simd transpose_vecs 2^{TRANSPOSE_LOG_SIZE}"), |b| { - b.iter_batched( - || transpose_values.clone().data, - |mut values| unsafe { - transpose_vecs( - transmute(values.as_mut_ptr()), - black_box(TRANSPOSE_LOG_SIZE as usize - 4), - ) - }, - BatchSize::LargeInput, - ); - }); -} - -pub fn simd_rfft(c: &mut Criterion) { - const LOG_SIZE: u32 = 20; - - let domain = CanonicCoset::new(LOG_SIZE).circle_domain(); - let twiddle_dbls = get_twiddle_dbls(domain.half_coset); - let twiddle_dbls_refs = twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(); - let values: BaseColumn = (0..domain.size()).map(BaseField::from).collect(); - - c.bench_function("simd rfft 20bit", |b| { - b.iter_with_large_drop(|| unsafe { - let mut target = Vec::::with_capacity(values.data.len()); - #[allow(clippy::uninit_vec)] - target.set_len(values.data.len()); - - fft( - black_box(transmute(values.data.as_ptr())), - transmute(target.as_mut_ptr()), - black_box(&twiddle_dbls_refs), - black_box(LOG_SIZE as usize), - ) - }); - }); -} - -criterion_group!( - name = benches; - config = Criterion::default().sample_size(10); - targets = simd_ifft, simd_ifft_parts, simd_rfft); -criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/field.rs b/Stwo_wrapper/crates/prover/benches/field.rs deleted file mode 100644 index acb318c..0000000 --- a/Stwo_wrapper/crates/prover/benches/field.rs +++ /dev/null @@ -1,150 +0,0 @@ -use criterion::{criterion_group, criterion_main, Criterion}; -use num_traits::One; -use rand::rngs::SmallRng; -use rand::{Rng, SeedableRng}; -use stwo_prover::core::backend::simd::m31::{PackedBaseField, N_LANES}; -use stwo_prover::core::fields::cm31::CM31; -use stwo_prover::core::fields::m31::{BaseField, M31}; -use stwo_prover::core::fields::qm31::SecureField; - -pub const N_ELEMENTS: usize = 1 << 16; -pub const N_STATE_ELEMENTS: usize = 8; - -pub fn m31_operations_bench(c: &mut Criterion) { - let mut rng = SmallRng::seed_from_u64(0); - let elements: Vec = (0..N_ELEMENTS).map(|_| rng.gen()).collect(); - let mut state: [M31; N_STATE_ELEMENTS] = rng.gen(); - - c.bench_function("M31 mul", |b| { - b.iter(|| { - for elem in &elements { - for _ in 0..128 { - for state_elem in &mut state { - *state_elem *= *elem; - } - } - } - }) - }); - - c.bench_function("M31 add", |b| { - b.iter(|| { - for elem in &elements { - for _ in 0..128 { - for state_elem in &mut state { - *state_elem += *elem; - } - } - } - }) - }); -} - -pub fn cm31_operations_bench(c: &mut Criterion) { - let mut rng = SmallRng::seed_from_u64(0); - let elements: Vec = (0..N_ELEMENTS).map(|_| rng.gen()).collect(); - let mut state: [CM31; N_STATE_ELEMENTS] = rng.gen(); - - c.bench_function("CM31 mul", |b| { - b.iter(|| { - for elem in &elements { - for _ in 0..128 { - for state_elem in &mut state { - *state_elem *= *elem; - } - } - } - }) - }); - - c.bench_function("CM31 add", |b| { - b.iter(|| { - for elem in &elements { - for _ in 0..128 { - for state_elem in &mut state { - *state_elem += *elem; - } - } - } - }) - }); -} - -pub fn qm31_operations_bench(c: &mut Criterion) { - let mut rng = SmallRng::seed_from_u64(0); - let elements: Vec = (0..N_ELEMENTS).map(|_| rng.gen()).collect(); - let mut state: [SecureField; N_STATE_ELEMENTS] = rng.gen(); - - c.bench_function("SecureField mul", |b| { - b.iter(|| { - for elem in &elements { - for _ in 0..128 { - for state_elem in &mut state { - *state_elem *= *elem; - } - } - } - }) - }); - - c.bench_function("SecureField add", |b| { - b.iter(|| { - for elem in &elements { - for _ in 0..128 { - for state_elem in &mut state { - *state_elem += *elem; - } - } - } - }) - }); -} - -pub fn simd_m31_operations_bench(c: &mut Criterion) { - let mut rng = SmallRng::seed_from_u64(0); - let elements: Vec = (0..N_ELEMENTS / N_LANES).map(|_| rng.gen()).collect(); - let mut states = vec![PackedBaseField::broadcast(BaseField::one()); N_STATE_ELEMENTS]; - - c.bench_function("mul_simd", |b| { - b.iter(|| { - for elem in elements.iter() { - for _ in 0..128 { - for state in states.iter_mut() { - *state *= *elem; - } - } - } - }) - }); - - c.bench_function("add_simd", |b| { - b.iter(|| { - for elem in elements.iter() { - for _ in 0..128 { - for state in states.iter_mut() { - *state += *elem; - } - } - } - }) - }); - - c.bench_function("sub_simd", |b| { - b.iter(|| { - for elem in elements.iter() { - for _ in 0..128 { - for state in states.iter_mut() { - *state -= *elem; - } - } - } - }) - }); -} - -criterion_group!( - name = benches; - config = Criterion::default().sample_size(10); - targets = m31_operations_bench, cm31_operations_bench, qm31_operations_bench, - simd_m31_operations_bench); -criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/fri.rs b/Stwo_wrapper/crates/prover/benches/fri.rs deleted file mode 100644 index 1c38a0e..0000000 --- a/Stwo_wrapper/crates/prover/benches/fri.rs +++ /dev/null @@ -1,35 +0,0 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use stwo_prover::core::backend::CpuBackend; -use stwo_prover::core::fields::m31::BaseField; -use stwo_prover::core::fields::qm31::SecureField; -use stwo_prover::core::fields::secure_column::SecureColumnByCoords; -use stwo_prover::core::fri::FriOps; -use stwo_prover::core::poly::circle::{CanonicCoset, PolyOps}; -use stwo_prover::core::poly::line::{LineDomain, LineEvaluation}; - -fn folding_benchmark(c: &mut Criterion) { - const LOG_SIZE: u32 = 12; - let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset()); - let evals = LineEvaluation::new( - domain, - SecureColumnByCoords { - columns: std::array::from_fn(|i| { - vec![BaseField::from_u32_unchecked(i as u32); 1 << LOG_SIZE] - }), - }, - ); - let alpha = SecureField::from_u32_unchecked(2213980, 2213981, 2213982, 2213983); - let twiddles = CpuBackend::precompute_twiddles(domain.coset()); - c.bench_function("fold_line", |b| { - b.iter(|| { - black_box(CpuBackend::fold_line( - black_box(&evals), - black_box(alpha), - &twiddles, - )); - }) - }); -} - -criterion_group!(benches, folding_benchmark); -criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/lookups.rs b/Stwo_wrapper/crates/prover/benches/lookups.rs deleted file mode 100644 index ac45a95..0000000 --- a/Stwo_wrapper/crates/prover/benches/lookups.rs +++ /dev/null @@ -1,104 +0,0 @@ -use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; -use rand::distributions::{Distribution, Standard}; -use rand::rngs::SmallRng; -use rand::{Rng, SeedableRng}; -use stwo_prover::core::backend::simd::SimdBackend; -use stwo_prover::core::backend::CpuBackend; -use stwo_prover::core::channel::Blake2sChannel; -use stwo_prover::core::fields::Field; -use stwo_prover::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer}; -use stwo_prover::core::lookups::mle::{Mle, MleOps}; - -const LOG_N_ROWS: u32 = 16; - -fn bench_gkr_grand_product(c: &mut Criterion, id: &str) { - let mut rng = SmallRng::seed_from_u64(0); - let layer = Layer::::GrandProduct(gen_random_mle(&mut rng, LOG_N_ROWS)); - c.bench_function(&format!("{id} grand product lookup 2^{LOG_N_ROWS}"), |b| { - b.iter_batched( - || layer.clone(), - |layer| prove_batch(&mut Blake2sChannel::default(), vec![layer]), - BatchSize::LargeInput, - ) - }); - c.bench_function( - &format!("{id} grand product lookup batch 4x 2^{LOG_N_ROWS}"), - |b| { - b.iter_batched( - || vec![layer.clone(), layer.clone(), layer.clone(), layer.clone()], - |layers| prove_batch(&mut Blake2sChannel::default(), layers), - BatchSize::LargeInput, - ) - }, - ); -} - -fn bench_gkr_logup_generic(c: &mut Criterion, id: &str) { - let mut rng = SmallRng::seed_from_u64(0); - let generic_layer = Layer::::LogUpGeneric { - numerators: gen_random_mle(&mut rng, LOG_N_ROWS), - denominators: gen_random_mle(&mut rng, LOG_N_ROWS), - }; - c.bench_function(&format!("{id} generic logup lookup 2^{LOG_N_ROWS}"), |b| { - b.iter_batched( - || generic_layer.clone(), - |layer| prove_batch(&mut Blake2sChannel::default(), vec![layer]), - BatchSize::LargeInput, - ) - }); -} - -fn bench_gkr_logup_multiplicities(c: &mut Criterion, id: &str) { - let mut rng = SmallRng::seed_from_u64(0); - let multiplicities_layer = Layer::::LogUpMultiplicities { - numerators: gen_random_mle(&mut rng, LOG_N_ROWS), - denominators: gen_random_mle(&mut rng, LOG_N_ROWS), - }; - c.bench_function( - &format!("{id} multiplicities logup lookup 2^{LOG_N_ROWS}"), - |b| { - b.iter_batched( - || multiplicities_layer.clone(), - |layer| prove_batch(&mut Blake2sChannel::default(), vec![layer]), - BatchSize::LargeInput, - ) - }, - ); -} - -fn bench_gkr_logup_singles(c: &mut Criterion, id: &str) { - let mut rng = SmallRng::seed_from_u64(0); - let singles_layer = Layer::::LogUpSingles { - denominators: gen_random_mle(&mut rng, LOG_N_ROWS), - }; - c.bench_function(&format!("{id} singles logup lookup 2^{LOG_N_ROWS}"), |b| { - b.iter_batched( - || singles_layer.clone(), - |layer| prove_batch(&mut Blake2sChannel::default(), vec![layer]), - BatchSize::LargeInput, - ) - }); -} - -/// Generates a random multilinear polynomial. -fn gen_random_mle, F: Field>(rng: &mut impl Rng, n_variables: u32) -> Mle -where - Standard: Distribution, -{ - Mle::new((0..1 << n_variables).map(|_| rng.gen()).collect()) -} - -fn gkr_lookup_benches(c: &mut Criterion) { - bench_gkr_grand_product::(c, "simd"); - bench_gkr_logup_generic::(c, "simd"); - bench_gkr_logup_multiplicities::(c, "simd"); - bench_gkr_logup_singles::(c, "simd"); - - bench_gkr_grand_product::(c, "cpu"); - bench_gkr_logup_generic::(c, "cpu"); - bench_gkr_logup_multiplicities::(c, "cpu"); - bench_gkr_logup_singles::(c, "cpu"); -} - -criterion_group!(benches, gkr_lookup_benches); -criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/matrix.rs b/Stwo_wrapper/crates/prover/benches/matrix.rs deleted file mode 100644 index 8e44a98..0000000 --- a/Stwo_wrapper/crates/prover/benches/matrix.rs +++ /dev/null @@ -1,63 +0,0 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use rand::rngs::SmallRng; -use rand::{Rng, SeedableRng}; -use stwo_prover::core::fields::m31::{M31, P}; -use stwo_prover::core::fields::qm31::QM31; -use stwo_prover::math::matrix::{RowMajorMatrix, SquareMatrix}; - -const MATRIX_SIZE: usize = 24; -const QM31_MATRIX_SIZE: usize = 6; - -// TODO(ShaharS): Share code with other benchmarks. -fn row_major_matrix_multiplication_bench(c: &mut Criterion) { - let mut rng = SmallRng::seed_from_u64(0); - - let matrix_m31 = RowMajorMatrix::::new( - (0..MATRIX_SIZE.pow(2)) - .map(|_| rng.gen()) - .collect::>(), - ); - - let matrix_qm31 = RowMajorMatrix::::new( - (0..QM31_MATRIX_SIZE.pow(2)) - .map(|_| rng.gen()) - .collect::>(), - ); - - // Create vector M31. - let vec: [M31; MATRIX_SIZE] = rng.gen(); - - // Create vector QM31. - let vec_qm31: [QM31; QM31_MATRIX_SIZE] = [(); QM31_MATRIX_SIZE].map(|_| { - QM31::from_u32_unchecked( - rng.gen::() % P, - rng.gen::() % P, - rng.gen::() % P, - rng.gen::() % P, - ) - }); - - // bench matrix multiplication. - c.bench_function( - &format!("RowMajorMatrix M31 {size}x{size} mul", size = MATRIX_SIZE), - |b| { - b.iter(|| { - black_box(matrix_m31.mul(vec)); - }) - }, - ); - c.bench_function( - &format!( - "QM31 RowMajorMatrix {size}x{size} mul", - size = QM31_MATRIX_SIZE - ), - |b| { - b.iter(|| { - black_box(matrix_qm31.mul(vec_qm31)); - }) - }, - ); -} - -criterion_group!(benches, row_major_matrix_multiplication_bench); -criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/merkle.rs b/Stwo_wrapper/crates/prover/benches/merkle.rs deleted file mode 100644 index c039be7..0000000 --- a/Stwo_wrapper/crates/prover/benches/merkle.rs +++ /dev/null @@ -1,38 +0,0 @@ -#![feature(iter_array_chunks)] - -use criterion::{criterion_group, criterion_main, Criterion, Throughput}; -use itertools::Itertools; -use num_traits::Zero; -use stwo_prover::core::backend::simd::SimdBackend; -use stwo_prover::core::backend::{Col, CpuBackend}; -use stwo_prover::core::fields::m31::{BaseField, N_BYTES_FELT}; -use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleHasher; -use stwo_prover::core::vcs::ops::MerkleOps; - -const LOG_N_ROWS: u32 = 16; - -const LOG_N_COLS: u32 = 8; - -fn bench_blake2s_merkle>(c: &mut Criterion, id: &str) { - let col: Col = (0..1 << LOG_N_ROWS).map(|_| BaseField::zero()).collect(); - let cols = (0..1 << LOG_N_COLS).map(|_| col.clone()).collect_vec(); - let col_refs = cols.iter().collect_vec(); - let mut group = c.benchmark_group("merkle throughput"); - let n_elements = 1 << (LOG_N_COLS + LOG_N_ROWS); - group.throughput(Throughput::Elements(n_elements)); - group.throughput(Throughput::Bytes(N_BYTES_FELT as u64 * n_elements)); - group.bench_function(&format!("{id} merkle"), |b| { - b.iter_with_large_drop(|| B::commit_on_layer(LOG_N_ROWS, None, &col_refs)) - }); -} - -fn blake2s_merkle_benches(c: &mut Criterion) { - bench_blake2s_merkle::(c, "simd"); - bench_blake2s_merkle::(c, "cpu"); -} - -criterion_group!( - name = benches; - config = Criterion::default().sample_size(10); - targets = blake2s_merkle_benches); -criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/pcs.rs b/Stwo_wrapper/crates/prover/benches/pcs.rs deleted file mode 100644 index da185d7..0000000 --- a/Stwo_wrapper/crates/prover/benches/pcs.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::iter; - -use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; -use rand::rngs::SmallRng; -use rand::{Rng, SeedableRng}; -use stwo_prover::core::backend::simd::SimdBackend; -use stwo_prover::core::backend::{BackendForChannel, CpuBackend}; -use stwo_prover::core::channel::Blake2sChannel; -use stwo_prover::core::fields::m31::BaseField; -use stwo_prover::core::pcs::CommitmentTreeProver; -use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; -use stwo_prover::core::poly::twiddles::TwiddleTree; -use stwo_prover::core::poly::BitReversedOrder; -use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; - -const LOG_COSET_SIZE: u32 = 20; -const LOG_BLOWUP_FACTOR: u32 = 1; -const N_POLYS: usize = 16; - -fn benched_fn>( - evals: Vec>, - channel: &mut Blake2sChannel, - twiddles: &TwiddleTree, -) { - let polys = evals - .into_iter() - .map(|eval| eval.interpolate_with_twiddles(twiddles)) - .collect(); - - CommitmentTreeProver::::new( - polys, - LOG_BLOWUP_FACTOR, - channel, - twiddles, - ); -} - -fn bench_pcs>(c: &mut Criterion, id: &str) { - let small_domain = CanonicCoset::new(LOG_COSET_SIZE); - let big_domain = CanonicCoset::new(LOG_COSET_SIZE + LOG_BLOWUP_FACTOR); - let twiddles = B::precompute_twiddles(big_domain.half_coset()); - let mut channel = Blake2sChannel::default(); - let mut rng = SmallRng::seed_from_u64(0); - - let evals: Vec> = iter::repeat_with(|| { - CircleEvaluation::new( - small_domain.circle_domain(), - (0..1 << LOG_COSET_SIZE).map(|_| rng.gen()).collect(), - ) - }) - .take(N_POLYS) - .collect(); - - c.bench_function( - &format!("{id} polynomial commitment 2^{LOG_COSET_SIZE}"), - |b| { - b.iter_batched( - || evals.clone(), - |evals| { - benched_fn::( - black_box(evals), - black_box(&mut channel), - black_box(&twiddles), - ) - }, - BatchSize::LargeInput, - ); - }, - ); -} - -fn pcs_benches(c: &mut Criterion) { - bench_pcs::(c, "simd"); - bench_pcs::(c, "cpu"); -} - -criterion_group!( - name = benches; - config = Criterion::default().sample_size(10); - targets = pcs_benches); -criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/poseidon.rs b/Stwo_wrapper/crates/prover/benches/poseidon.rs deleted file mode 100644 index bc796c6..0000000 --- a/Stwo_wrapper/crates/prover/benches/poseidon.rs +++ /dev/null @@ -1,18 +0,0 @@ -use criterion::{criterion_group, criterion_main, Criterion, Throughput}; -use stwo_prover::core::pcs::PcsConfig; -use stwo_prover::examples::poseidon::prove_poseidon; - -pub fn simd_poseidon(c: &mut Criterion) { - const LOG_N_INSTANCES: u32 = 18; - let mut group = c.benchmark_group("poseidon2"); - group.throughput(Throughput::Elements(1u64 << LOG_N_INSTANCES)); - group.bench_function(format!("poseidon2 2^{} instances", LOG_N_INSTANCES), |b| { - b.iter(|| prove_poseidon(LOG_N_INSTANCES, PcsConfig::default())); - }); -} - -criterion_group!( - name = bit_rev; - config = Criterion::default().sample_size(10); - targets = simd_poseidon); -criterion_main!(bit_rev); diff --git a/Stwo_wrapper/crates/prover/benches/prefix_sum.rs b/Stwo_wrapper/crates/prover/benches/prefix_sum.rs deleted file mode 100644 index 7faf4ac..0000000 --- a/Stwo_wrapper/crates/prover/benches/prefix_sum.rs +++ /dev/null @@ -1,19 +0,0 @@ -use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; -use stwo_prover::core::backend::simd::column::BaseColumn; -use stwo_prover::core::backend::simd::prefix_sum::inclusive_prefix_sum; -use stwo_prover::core::fields::m31::BaseField; - -pub fn simd_prefix_sum_bench(c: &mut Criterion) { - const LOG_SIZE: u32 = 24; - let evals: BaseColumn = (0..1 << LOG_SIZE).map(BaseField::from).collect(); - c.bench_function(&format!("simd prefix_sum 2^{LOG_SIZE}"), |b| { - b.iter_batched( - || evals.clone(), - inclusive_prefix_sum, - BatchSize::LargeInput, - ); - }); -} - -criterion_group!(benches, simd_prefix_sum_bench); -criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/quotients.rs b/Stwo_wrapper/crates/prover/benches/quotients.rs deleted file mode 100644 index fc2949a..0000000 --- a/Stwo_wrapper/crates/prover/benches/quotients.rs +++ /dev/null @@ -1,55 +0,0 @@ -#![feature(iter_array_chunks)] - -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use itertools::Itertools; -use stwo_prover::core::backend::cpu::CpuBackend; -use stwo_prover::core::backend::simd::SimdBackend; -use stwo_prover::core::circle::SECURE_FIELD_CIRCLE_GEN; -use stwo_prover::core::fields::m31::BaseField; -use stwo_prover::core::fields::qm31::SecureField; -use stwo_prover::core::pcs::quotients::{ColumnSampleBatch, QuotientOps}; -use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; -use stwo_prover::core::poly::BitReversedOrder; - -// TODO(andrew): Consider removing const generics and making all sizes the same. -fn bench_quotients( - c: &mut Criterion, - id: &str, -) { - let domain = CanonicCoset::new(LOG_N_ROWS).circle_domain(); - let values = (0..domain.size()).map(BaseField::from).collect(); - let col = CircleEvaluation::::new(domain, values); - let cols = (0..1 << LOG_N_COLS).map(|_| col.clone()).collect_vec(); - let col_refs = cols.iter().collect_vec(); - let random_coeff = SecureField::from_u32_unchecked(0, 1, 2, 3); - let a = SecureField::from_u32_unchecked(5, 6, 7, 8); - let samples = vec![ColumnSampleBatch { - point: SECURE_FIELD_CIRCLE_GEN, - columns_and_values: (0..1 << LOG_N_COLS).map(|i| (i, a)).collect(), - }]; - c.bench_function( - &format!("{id} quotients 2^{LOG_N_COLS} x 2^{LOG_N_ROWS}"), - |b| { - b.iter_with_large_drop(|| { - B::accumulate_quotients( - black_box(domain), - black_box(&col_refs), - black_box(random_coeff), - black_box(&samples), - 1, - ) - }) - }, - ); -} - -fn quotients_benches(c: &mut Criterion) { - bench_quotients::(c, "simd"); - bench_quotients::(c, "cpu"); -} - -criterion_group!( - name = benches; - config = Criterion::default().sample_size(10); - targets = quotients_benches); -criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/proof.json b/Stwo_wrapper/crates/prover/proof.json deleted file mode 100644 index eef777e..0000000 --- a/Stwo_wrapper/crates/prover/proof.json +++ /dev/null @@ -1,348 +0,0 @@ -{ - "commitments" : - ["34328580272026076035687604093297365442785733592720865218001799813393342152908", - "38388381845372648579572899115609862601821983406101214230086519922780265042634"], - - "sampled_values_0" : - [["1","0","0","0"], - ["2129160320","1109509513","787887008","1676461964"], - ["262908602","915488457","1893945291","1774327476"], - ["894719153","1570509766","1424186619","204092576"], - ["397490811","836398274","1615765624","2013800563"], - ["1022303904","276983775","1064742229","165204856"], - ["1200363525","170838026","524999776","156116441"], - ["850733526","448725560","1521962209","1318190714"], - ["1187866075","1705588092","924088348","490002418"], - ["2033565088","996780784","1820235518","2048788344"], - ["2061590372","1150986157","711772586","1511398564"], - ["1066623954","530384603","1890251380","1699008129"], - ["734047580","1685768538","505142109","787113212"], - ["2030904700","99932423","695391286","1736941035"], - ["1580330105","932031717","1705998668","146411959"], - ["1585732224","1556242253","941668238","1998570239"], - ["199481433","2123320403","1257464748","1663811899"], - ["2139019524","1547107722","728449250","1941851166"], - ["752079023","268472135","1465850435","16510773"], - ["1279312817","63252415","442230579","1560954631"], - ["1074859131","137997593","2118329011","652535723"], - ["297567647","1483381078","1941495981","599737348"], - ["1735543786","1420676479","1354982762","1114211268"], - ["1691705401","1143446295","1748115479","1666756627"], - ["955696743","2077778309","736065989","1319443838"], - ["1076874307","1001483910","1702287354","819727011"], - ["1134989244","1823710400","2067694105","1098263343"], - ["1793642608","961404475","1279773056","1815400043"], - ["739677274","1827877577","838562378","171296720"], - ["2036367121","1901888610","289723252","2014426907"], - ["330020507","436937516","2113056521","1828501207"], - ["1359068814","583899921","734628376","1223217137"], - ["1319501520","1242972089","1202216521","1285024997"], - ["681182370","1569622309","1574376904","1563950435"], - ["1204519566","483612224","1677731115","1667757584"], - ["330284364","917877098","57538161","179869993"], - ["2056561198","119768893","740294154","1454562198"], - ["79009084","545196641","13388962","1973400144"], - ["885977898","1973300145","37115619","957100699"], - ["1937449867","1777683674","1983002799","757662558"], - ["344927561","357845689","26887161","664585634"], - ["1462268220","615463524","209500386","44308852"], - ["570984705","2022111132","1404632615","2119081660"], - ["13183327","1584451280","1216116653","316345540"], - ["1497965915","705236857","1892466476","1068567492"], - ["1758694676","1408790161","1140545981","315723937"], - ["645308461","1125824784","1786470558","1240927727"], - ["1213464061","470930291","1718629724","1149088875"], - ["214577693","1578610321","2133720991","226291629"], - ["1357706729","2097875841","1767996253","1478111500"], - ["1154658683","752162439","2018723944","163997560"], - ["1051993583","703716977","379706674","487262860"], - ["1017692573","2060296775","2001023083","1064213951"], - ["1042587725","1701108370","204550428","904590130"], - ["1115340870","743420370","1927225111","1276396551"], - ["493638626","1874789377","47342513","209203758"], - ["1558586505","83459476","247638703","1975504267"], - ["2097068784","954319448","367516919","1545761518"], - ["1655645294","352838520","1307263981","1110198118"], - ["1169856046","1925368371","1362317240","1926032147"], - ["1940113709","885624001","1395047654","80053995"], - ["1778932990","25092730","201117282","1724571908"], - ["2096327738","233411984","1247443120","713989449"], - ["808532602","136577890","1015579288","38900716"], - ["1182257782","1186245376","1451332036","2080170103"], - ["1662610758","1505542080","1038243031","1889715771"], - ["440146119","942837214","1440484295","1593949278"], - ["46258268","1884246120","164930024","2050584510"], - ["1198954868","1079638495","1424072583","1028611344"], - ["2112984649","1531382496","1873151714","1818301795"], - ["1554382282","253920307","1641628530","1378998084"], - ["857898234","686236793","2091871553","184978860"], - ["2049153599","6111471","1579475775","32492894"], - ["1371356596","679072793","1547377985","354305233"], - ["1799882226","1201472049","1592617716","125534957"], - ["1277144880","253726080","1800145982","1125162267"], - ["1577717920","440984421","1377891036","846453148"], - ["1952731919","1710992214","673668053","1871913638"], - ["1559011028","2060945859","719954448","1356468891"], - ["1961642242","1693473944","1300152522","412222111"], - ["861208187","1242659514","977183954","38730935"], - ["1016984917","1368361439","2106430139","1225979890"], - ["1427754325","1482206106","1465316380","1096279813"], - ["566051043","2025874544","234976335","1482256978"], - ["1750543495","1494541462","374330732","411642241"], - ["230654343","55625728","136463431","1099606808"], - ["1172218793","1260458608","1314942990","75527287"], - ["1824515276","916178746","1300275105","370626746"], - ["915931367","987018043","56193044","617907884"], - ["1934695822","1112844637","609268252","1972086910"], - ["619631651","152029630","1979976905","292597437"], - ["62258350","1890115432","1373605674","1505619938"], - ["1770422019","1398189304","1773172351","1576001433"], - ["650940868","1756047014","1764798953","1146887875"], - ["1746945043","528205234","778346028","1797468521"], - ["760802416","1479409742","1556974632","1307498378"], - ["102511022","1787975482","968854748","1010240763"], - ["330722054","2046294448","14132125","1822414050"], - ["943548871","1770900623","1861740461","1290634078"], - ["1402661415","1361511065","1784889120","837615360"]], - - "sampled_values_1" : - [["712066144","1576368753","626134398","426337436"], - ["160634493","1096735733","992622982","964509862"], - ["208900621","1128739590","1423579079","1688318061"], - ["1029182234","1152361165","571476481","1593867154"]], - - "decommitment_0" : - ["24311567319749512546399129581715033328970605051392227451685196018312506896509", - "8450134967305372517473027560161707471995673370792264153422077885080332622841", - "6431507699794114682519586182713221908058047520896405293833270087934517909753", - "303109001984349840640377328716025252051982378448629744935456455431709129012", - "47167328465744900593371601186109726758160197572292632388959155138584359581158", - "50584492046778480438774038937088410409133167768957478525289857065775850658491", - "30584798499699103841624545814425958941934653399588880797257122471101102880636", - "33441256878213890325682161124370878299436204406591246133637659120215439522803", - "3288124068330032280185519028600654292250668929588668389702892483946668251740", - "29852774919556057664485671676242264613416836486089146650713214180894511265116", - "12482060975231949385592255321766253365687502822944549845564491620341379321204", - "46285234163162336949700608657672147469543559995399282843606812790099228411758", - "20807128972645591294726020136444795908525656782422245307591812614900798799914"], - - "decommitment_1" : - ["18063303111481257844109225560025890393366258018933166919604543575686388632162", - "1676364734386980395984608216327451243278421019544108756198322792517099196249", - "5278661052518480850653886996628582549184134231869598116690316714367933376948", - "21983822689977371558234298346357617674436224016274009820764238516520240403273", - "605332543427376153930374757063581881998320956602375739165671986207155079359", - "33771702906041565783498389127165108212044382608172583325407071671862086994048", - "6930451780154275491146135719028766497977496109537963233244808739657647563071", - "4564117668410212714684125903928600765456322272915099403587425756488534507713", - "337808767671877648828499299861821973796749820854854708379479049898835100991", - "2725840354457305623692571800192492803162041315546256970381708693201407812833", - "34716495111790106826563330917176360656701717867702196654990744866499300990003", - "49719870445464463616785535809529171382800153139923763422202182182379572350737", - "48664746464641275030915461677298150155193593333108431337941029583245720868695", - "32557886668297237033601675259512842580727821006475208499640136322794706303894", - "15014835402414421167586357788116276188694467622586221351644991310645286648480", - "12973705814659120511327850727547995427054971555827754777395732787633567627149", - "15379912850866472398958956306527914195058439699787840041152620034933267404138", - "19859070819439084101412868355121176941090844577507824922960697697668791429525", - "38273559034692361632775489704953448699371080776239846995670381153186834620044"], - - "queried_values_0" : - [["1","1","1","1","1","1"], - ["730457281","730490049","28918683","28885915","1656126010","1656093242"], - ["855614122","1238037465","1836504291","355791428","757095818","467806903"], - ["674179888","1530445315","1720543014","76190330","1475912409","1017215862"], - ["1142290008","1148671853","1619097781","938511401","904357795","257652679"], - ["1679234056","1355264641","2139729457","574756654","604307234","1146556949"], - ["1000500309","2008905806","1442759180","598876729","1786070690","1072293976"], - ["1119085545","2133345582","135683580","216214405","1049766224","943727969"], - ["206423262","2047139937","305085364","1422472664","1826554088","1032095092"], - ["501238882","1656305868","724710382","1949772461","1426787917","585368894"], - ["1005468045","1775577441","1042182076","415631363","1067013227","1635705270"], - ["1776076392","216798814","1525036520","1160666510","1212132211","1915058776"], - ["859923105","1633989410","182110635","2060185314","1084464822","1129902257"], - ["489437802","313401022","271315129","357612175","2050381179","647577687"], - ["1495302158","2052264981","1498165299","1164417520","1050104037","450244199"], - ["1084986392","398966983","808449145","1733554138","2068501028","659474347"], - ["399458768","1789245133","1698759035","188433436","1794535430","364419824"], - ["2013965647","722839714","928854328","124488895","1378959529","952886009"], - ["1334765706","193402268","471076108","640800921","1998121783","961582406"], - ["1067762968","381831281","560459357","1025929344","181659877","1922040224"], - ["1993303462","467991218","849673597","744722836","239634354","329631295"], - ["785794488","1649178388","672964420","1281255462","900602801","271501809"], - ["857859728","1325395820","985014020","1094321795","259553347","774587048"], - ["1214640090","1588569866","871717820","1131833706","1625896842","1635087550"], - ["796549205","931495223","2018253108","1395065060","158209751","1160478135"], - ["883143962","729115354","190207821","839273168","1668931939","2074584689"], - ["1490296658","1846956206","1610364850","56422972","160482417","681872093"], - ["1270585092","1910190167","464113273","613529242","1027101122","1014185686"], - ["1456043179","1999662961","193940913","678382864","39040067","1236859818"], - ["1626243617","901735777","1703169024","911300891","1640727682","1121874896"], - ["492192896","15672698","319327174","1727120334","1965889437","114404366"], - ["407079019","949462637","255390508","1753162095","501134776","1457122467"], - ["1478573872","1439193434","1053200675","1001140887","1553935777","1253681552"], - ["183135520","946237525","1802924023","1831496784","1893117930","1830486286"], - ["234902670","1169030504","196055115","1323151968","855748623","1328842866"], - ["1150999776","1338824346","2072101698","774206263","1967350016","1808817867"], - ["924341552","1430286424","511268814","825025920","1061850574","1954646566"], - ["302634890","314434153","1692670768","1822915313","1244352075","1953834230"], - ["1576167467","687837005","2116136752","144109400","1590157548","1634932462"], - ["396756275","1272134898","1207308240","818219166","1314182589","109494000"], - ["846425160","897737569","757312164","826009489","1019831588","1977463051"], - ["2065801114","1918982367","1548689186","2082631803","298112070","383438809"], - ["1034102289","461735180","2115581275","1343026598","1229979058","1021418523"], - ["1784173874","166635387","547550115","1094693960","573193735","451367040"], - ["119818313","659105018","1741377697","8940733","911200334","511474518"], - ["1511949880","1315119529","1267019200","2134944693","878254810","375758264"], - ["1203050254","156394547","1348568635","412863443","1068659960","1407913814"], - ["361779719","130417374","89109096","117994876","1151322919","863143484"], - ["1007476533","989566160","138644964","1672742874","540141118","1296408100"], - ["1824241144","2051199719","1863718547","2109877864","36689613","1055926854"], - ["791693003","1433717239","991140958","1565955371","1839976870","1163947838"], - ["1267320759","2102593211","1831360854","1691591439","1672201908","61327345"], - ["301343164","277158258","627925439","577508975","1896464649","907629062"], - ["1964268932","929590164","1529686876","68630644","1663063136","254082844"], - ["693529348","1815295486","1660565870","1226857377","156312343","1500907098"], - ["1723158753","252348225","253985470","52424437","1605949937","576572581"], - ["1781792048","1497492716","1951824572","1156925855","863650708","156447987"], - ["876432605","458503399","283092867","247883110","1227074181","966219235"], - ["1581118191","66527915","1577039825","1227961402","1738412997","1862462297"], - ["679458448","338624032","34185999","253532412","65409631","563033132"], - ["1011967612","1898273226","2124401156","105282260","1188226330","123913515"], - ["1432513005","1083162825","1299704150","1184276814","339749370","1064298821"], - ["145070927","1250457746","1977306722","1035124433","322154361","1782232869"], - ["1934256728","1269667423","248999825","267009036","132507662","474021315"], - ["1694871453","535187899","981724703","1180312550","74370795","277702656"], - ["1850841912","1037634121","1967377497","1755127193","1566449422","1939039785"], - ["1315699598","476111459","2058733537","332263289","1592057567","874912616"], - ["774536537","170060616","2086574090","47894465","778021586","2115296942"], - ["322468558","24934377","637275739","1596346002","1896623296","1814433409"], - ["766517428","1263076038","358941187","2070217919","2108397185","1587546402"], - ["345404490","1065320570","1231275245","1037359122","1286389839","2070140848"], - ["746521574","835067673","311114030","1586400488","1406022058","1284151326"], - ["1857315969","431410759","825259098","1717904860","503708539","2097758215"], - ["1879479734","1863555039","2108235515","1833922769","1562707156","49484002"], - ["1366768987","1050390036","1491845132","666041968","74368055","1254335623"], - ["188857287","1161878039","1771805176","1457666227","1157868840","486461459"], - ["261764705","1577846886","1332322961","10423372","640027252","1086814656"], - ["111907709","542625019","2021749229","2013008690","523703611","1328833940"], - ["1270684472","1675989474","394214608","538100201","1984625073","1560563159"], - ["666555709","852557426","651115051","1878827907","953346499","619017191"], - ["747972907","149382079","1393306586","1394823957","960994901","536632180"], - ["80535893","229602380","1817483938","1455260088","484432","1869486290"], - ["443556931","253108261","1609174393","1245931188","752691602","1668543792"], - ["745497042","854686466","1834097777","642389535","1284043061","896553209"], - ["532777064","1491985134","200157005","1378855967","1159213374","1797221037"], - ["933463176","813761538","1124049829","1988347055","2115297439","1836576920"], - ["194436043","1437728625","43998833","786326005","1130428925","1424571033"], - ["2108272636","1410841489","753065553","2020187193","1644376367","670324352"], - ["1362448669","752702510","1740531646","47989265","588634","1940480814"], - ["439960422","528245604","179496898","235775013","59000527","1903150726"], - ["103605138","249162711","1971219628","1958189530","423278905","1318354885"], - ["321504059","1595801356","596911575","1361967073","459661104","599048233"], - ["1610552125","73166668","444776743","1820306524","1180674369","1570356756"], - ["1283703846","1024562975","958477092","1329464736","1758672211","899108631"], - ["713626137","904634570","1902566483","1938333063","447549083","703262660"], - ["1417291696","1717451368","354584524","832751684","930128006","1037860604"], - ["1618745108","79533863","301008038","2091942909","1221962725","1524945081"], - ["1490224031","349040760","393684137","484089443","1912848485","1790207999"], - ["1930411520","642009628","1138820074","31855314","1177766391","1913457637"], - ["618628623","139430131","904498895","925273128","2111653256","1012250155"]], - - "queried_values_1" : [["316772341","1526280133","663010112","224983897","510598760","1109503351"], - ["754832207","435790299","883623752","553207508","154784232","199176676"], - ["689603315","1763523007","1720552945","1983603154","367841669","319325418"], - ["1290247052","1120744584","193500372","294491115","951360807","891034447"]], - - "proof of work" : "43", - - "coeffs" : ["329725079","667313404","2083859876","1645693780"], - - "inner_commitment_0" : "45555755014923146766476222823122654194153582923386372098787021809602091298670", - - "inner_decommitment_0" : - ["5462985033728555575703006689913665598917262836577853690370070265073174719979", - "1439463537898028163672031322449473184361454650351831891678420729829634422052", - "3457992889375232257419443221173459860069710025424105705342440558952480618733", - "1560236756242436900530359200127475252176167690738804367783218449678388748008", - "44704035195642947074853484428073358106137689624692539529545008114744235429869", - "1823722396825684657353300460130959150013626547749061669644787096209462088081", - "36173961213090661587166983586606381085726394354099487131775975739295244107044", - "44982414287882497869227873977178787463294164719732705041425254748922385395330", - "45309152665928987599895709328729600158686240484613667382186228529382344291065", - "6727671460449390719545303211873485738734289534911560838410979450236173250575", - "20261836592905556803606993653450529841562498576647326913889589677099044443740", - "5092741370842944151567575040255811031129982749889308756570945791769905531735", - "25669118913473215056777887930215946759071827607816131638248897691395809170466"], - - "inner_evals_subset_0" : - [["1779738283","1440487135","622229563","2027928845"], - ["3310770","2003547458","1663490902","2105455978"], - ["824310523","1757518542","231582441","427507918"]], - - "inner_commitment_1" : "36809196688918736785151875655523363274066842107451407514854169291514772437712", - - "inner_decommitment_1" : - ["45836684762941279488847688946861562185184567552524644394596887832754938056979", - "19073399120361742411025793373596546929954434607182120440274945820526786807031", - "10149408627738268983458395386544516732232569888429637086606165543488301475287", - "1134561121197915913438467336074096605141057500441242329860673894339623319838", - "4117331370612733116088255919125911273783755080775248279590234839282346655209", - "10513576479117646185507147947418059201301955276798193127377481002238952180861", - "23342730846670478933566004300392843380760787769231395287565422468403932432226", - "18740447275326464369699774196183138103532792867824812360537906002797944316369", - "14414140432147963417180883803716587948502362440108737731008978168633980398219", - "12250577863311462987061070846596950168057013378171600204130942328863618210853"], - - "inner_evals_subset_1" : - [["1993244979","993480712","1202910330","51538179"], - ["1550507975","1444313548","117070947","1740854590"], - ["342709844","601149328","1436490544","1384381104"]], - - "inner_commitment_2" : "881819876414071785116043072319901261019430342891967010427739931379717752179", - - "inner_decommitment_2" : - ["3578061893060038632121836895066391994380789049478144565795630968916166958170", - "37159117331259094338569510749131708143302385613991325356583099983565421599794", - "37757292341028790385725896863222501721662517262004109998718347744172059822092", - "32039058093629437909283983486827138804118789892152561698628703527746673098335", - "11826129489683399268430621966121477905203281789801973078750629518404118320344", - "12209931412475259378582945309327673629248630080862144675946167350377528046777", - "23652788355891929123107872712785512551735081591713325931570432712643436668697"], - - "inner_evals_subset_2" : - [["1184374976","203117688","803515935","1781630737"], - ["716686867","138132852","2024080584","392488646"], - ["606958913","308986056","258114411","2075401741"]], - - "inner_commitment_3" : "12027868599227153144742193285247060272784688895537159385948117930165143367635", - - "inner_decommitment_3" : - ["46510393127320994984678433868779353751380750916123221099220042551249644407575", - "39205406309151711272632862117612882165913578213907083849769292791373274291092", - "1794620160371318211300608665903463515892826349097148444913826356018179256828", - "30764929521910551485060681298929856016753563990361023701791813032517030674324"], - - "inner_evals_subset_3" : - [["1682558787","129089003","784689440","491206249"], - ["2038210709","1600238918","655676259","1542271403"], - ["1014579052","1384080403","862591487","1941843578"]], - - "inner_commitment_4" : "14412699124489796400221638504502701429059474709940645969751865021865037702257", - - "inner_decommitment_4" : - ["19324767949149751902195880760061491860991545124249052461044586503970003688610"], - - "inner_evals_subset_4" : - [["683258805","1002722262","1583421272","1748673499"], - ["2101847208","689925082","1602280602","1942656531"], - ["1952987285","1995490213","2082219584","1620868519"]], - - "inner_commitment_5" : "7898658461322497542615494384418597990320440781916208640366283709415937814000", - - "inner_decommitment_5" : - [], - - "inner_evals_subset_5" : - [["1862940297","2014478284","1043383827","560191545"]] -} \ No newline at end of file diff --git a/Stwo_wrapper/crates/prover/src/constraint_framework/assert.rs b/Stwo_wrapper/crates/prover/src/constraint_framework/assert.rs deleted file mode 100644 index 9e5530b..0000000 --- a/Stwo_wrapper/crates/prover/src/constraint_framework/assert.rs +++ /dev/null @@ -1,84 +0,0 @@ -use num_traits::{One, Zero}; - -use super::EvalAtRow; -use crate::core::backend::{Backend, Column}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; -use crate::core::pcs::TreeVec; -use crate::core::poly::circle::{CanonicCoset, CirclePoly}; -use crate::core::utils::circle_domain_order_to_coset_order; - -/// Evaluates expressions at a trace domain row, and asserts constraints. Mainly used for testing. -pub struct AssertEvaluator<'a> { - pub trace: &'a TreeVec>>, - pub col_index: TreeVec, - pub row: usize, -} -impl<'a> AssertEvaluator<'a> { - pub fn new(trace: &'a TreeVec>>, row: usize) -> Self { - Self { - trace, - col_index: TreeVec::new(vec![0; trace.len()]), - row, - } - } -} -impl<'a> EvalAtRow for AssertEvaluator<'a> { - type F = BaseField; - type EF = SecureField; - - fn next_interaction_mask( - &mut self, - interaction: usize, - offsets: [isize; N], - ) -> [Self::F; N] { - let col_index = self.col_index[interaction]; - self.col_index[interaction] += 1; - offsets.map(|off| { - // The mask row might wrap around the column size. - let col_size = self.trace[interaction][col_index].len() as isize; - self.trace[interaction][col_index] - [(self.row as isize + off).rem_euclid(col_size) as usize] - }) - } - - fn add_constraint(&mut self, constraint: G) - where - Self::EF: std::ops::Mul, - { - // Cast to SecureField. - let res = SecureField::one() * constraint; - // The constraint should be zero at the given row, since we are evaluating on the trace - // domain. - assert_eq!(res, SecureField::zero(), "row: {}", self.row); - } - - fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { - SecureField::from_m31_array(values) - } -} - -pub fn assert_constraints( - trace_polys: &TreeVec>>, - trace_domain: CanonicCoset, - assert_func: impl Fn(AssertEvaluator<'_>), -) { - let traces = trace_polys.as_ref().map(|tree| { - tree.iter() - .map(|poly| { - circle_domain_order_to_coset_order( - &poly - .evaluate(trace_domain.circle_domain()) - .bit_reverse() - .values - .to_cpu(), - ) - }) - .collect() - }); - for row in 0..trace_domain.size() { - let eval = AssertEvaluator::new(&traces, row); - assert_func(eval); - } -} diff --git a/Stwo_wrapper/crates/prover/src/constraint_framework/component.rs b/Stwo_wrapper/crates/prover/src/constraint_framework/component.rs deleted file mode 100644 index c0d8319..0000000 --- a/Stwo_wrapper/crates/prover/src/constraint_framework/component.rs +++ /dev/null @@ -1,210 +0,0 @@ -use std::borrow::Cow; -use std::iter::zip; -use std::ops::Deref; - -use itertools::Itertools; -use tracing::{span, Level}; - -use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator}; -use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; -use crate::core::air::{Component, ComponentProver, Trace}; -use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords; -use crate::core::backend::simd::m31::LOG_N_LANES; -use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS}; -use crate::core::backend::simd::SimdBackend; -use crate::core::circle::CirclePoint; -use crate::core::constraints::coset_vanishing; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::FieldExpOps; -use crate::core::pcs::{TreeSubspan, TreeVec}; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; -use crate::core::poly::BitReversedOrder; -use crate::core::{utils, ColumnVec}; - -// TODO(andrew): Docs. -// TODO(andrew): Consider better location for this. -#[derive(Debug, Default)] -pub struct TraceLocationAllocator { - /// Mapping of tree index to next available column offset. - next_tree_offsets: TreeVec, -} - -impl TraceLocationAllocator { - fn next_for_structure(&mut self, structure: &TreeVec>) -> TreeVec { - if structure.len() > self.next_tree_offsets.len() { - self.next_tree_offsets.resize(structure.len(), 0); - } - - TreeVec::new( - zip(&mut *self.next_tree_offsets, &**structure) - .enumerate() - .map(|(tree_index, (offset, cols))| { - let col_start = *offset; - let col_end = col_start + cols.len(); - *offset = col_end; - TreeSubspan { - tree_index, - col_start, - col_end, - } - }) - .collect(), - ) - } -} - -/// A component defined solely in means of the constraints framework. -/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for -/// the SIMD backend. -/// Note that the constraint framework only support components with columns of the same size. -pub trait FrameworkEval { - fn log_size(&self) -> u32; - - fn max_constraint_log_degree_bound(&self) -> u32; - - fn evaluate(&self, eval: E) -> E; -} - -pub struct FrameworkComponent { - eval: C, - trace_locations: TreeVec, -} - -impl FrameworkComponent { - pub fn new(provider: &mut TraceLocationAllocator, eval: E) -> Self { - let eval_tree_structure = eval.evaluate(InfoEvaluator::default()).mask_offsets; - let trace_locations = provider.next_for_structure(&eval_tree_structure); - Self { - eval, - trace_locations, - } - } -} - -impl Component for FrameworkComponent { - fn n_constraints(&self) -> usize { - self.eval.evaluate(InfoEvaluator::default()).n_constraints - } - - fn max_constraint_log_degree_bound(&self) -> u32 { - self.eval.max_constraint_log_degree_bound() - } - - fn trace_log_degree_bounds(&self) -> TreeVec> { - TreeVec::new( - self.eval - .evaluate(InfoEvaluator::default()) - .mask_offsets - .iter() - .map(|tree_masks| vec![self.eval.log_size(); tree_masks.len()]) - .collect(), - ) - } - - fn mask_points( - &self, - point: CirclePoint, - ) -> TreeVec>>> { - let info = self.eval.evaluate(InfoEvaluator::default()); - let trace_step = CanonicCoset::new(self.eval.log_size()).step(); - info.mask_offsets.map_cols(|col_mask| { - col_mask - .iter() - .map(|off| point + trace_step.mul_signed(*off).into_ef()) - .collect() - }) - } - - fn evaluate_constraint_quotients_at_point( - &self, - point: CirclePoint, - mask: &TreeVec>>, - evaluation_accumulator: &mut PointEvaluationAccumulator, - ) { - self.eval.evaluate(PointEvaluator::new( - mask.sub_tree(&self.trace_locations), - evaluation_accumulator, - coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(), - )); - } -} - -impl ComponentProver for FrameworkComponent { - fn evaluate_constraint_quotients_on_domain( - &self, - trace: &Trace<'_, SimdBackend>, - evaluation_accumulator: &mut DomainEvaluationAccumulator, - ) { - let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain(); - let trace_domain = CanonicCoset::new(self.eval.log_size()); - - let component_polys = trace.polys.sub_tree(&self.trace_locations); - let component_evals = trace.evals.sub_tree(&self.trace_locations); - - // Extend trace if necessary. - // TODO(spapini): Don't extend when eval_size < committed_size. Instead, pick a good - // subdomain. - let need_to_extend = component_evals - .iter() - .flatten() - .any(|c| c.domain != eval_domain); - let trace: TreeVec< - Vec>>, - > = if need_to_extend { - let _span = span!(Level::INFO, "Extension").entered(); - let twiddles = SimdBackend::precompute_twiddles(eval_domain.half_coset); - component_polys - .as_cols_ref() - .map_cols(|col| Cow::Owned(col.evaluate_with_twiddles(eval_domain, &twiddles))) - } else { - component_evals.clone().map_cols(|c| Cow::Borrowed(*c)) - }; - - // Denom inverses. - let log_expand = eval_domain.log_size() - trace_domain.log_size(); - let mut denom_inv = (0..1 << log_expand) - .map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse()) - .collect_vec(); - utils::bit_reverse(&mut denom_inv); - - // Accumulator. - let [mut accum] = - evaluation_accumulator.columns([(eval_domain.log_size(), self.n_constraints())]); - accum.random_coeff_powers.reverse(); - - let _span = span!(Level::INFO, "Constraint pointwise eval").entered(); - let col = unsafe { VeryPackedSecureColumnByCoords::transform_under_mut(accum.col) }; - - for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)) { - let trace_cols = trace.as_cols_ref().map_cols(|c| c.as_ref()); - - // Evaluate constrains at row. - let eval = SimdDomainEvaluator::new( - &trace_cols, - vec_row, - &accum.random_coeff_powers, - trace_domain.log_size(), - eval_domain.log_size(), - ); - let row_res = self.eval.evaluate(eval).row_res; - - // Finalize row. - unsafe { - let denom_inv = VeryPackedBaseField::broadcast( - denom_inv[vec_row - >> (trace_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)], - ); - col.set_packed(vec_row, col.packed_at(vec_row) + row_res * denom_inv) - } - } - } -} - -impl Deref for FrameworkComponent { - type Target = E; - - fn deref(&self) -> &E { - &self.eval - } -} diff --git a/Stwo_wrapper/crates/prover/src/constraint_framework/constant_columns.rs b/Stwo_wrapper/crates/prover/src/constraint_framework/constant_columns.rs deleted file mode 100644 index e57df28..0000000 --- a/Stwo_wrapper/crates/prover/src/constraint_framework/constant_columns.rs +++ /dev/null @@ -1,37 +0,0 @@ -use num_traits::One; - -use crate::core::backend::{Backend, Col, Column}; -use crate::core::fields::m31::BaseField; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; -use crate::core::poly::BitReversedOrder; -use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; - -/// Generates a column with a single one at the first position, and zeros elsewhere. -pub fn gen_is_first(log_size: u32) -> CircleEvaluation { - let mut col = Col::::zeros(1 << log_size); - col.set(0, BaseField::one()); - CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col) -} - -/// Generates a column with `1` at every `2^log_step` positions, `0` elsewhere, shifted by offset. -// TODO(andrew): Consider optimizing. Is a quotients of two coset_vanishing (use succinct rep for -// verifier). -pub fn gen_is_step_with_offset( - log_size: u32, - log_step: u32, - offset: usize, -) -> CircleEvaluation { - let mut col = Col::::zeros(1 << log_size); - - let size = 1 << log_size; - let step = 1 << log_step; - let step_offset = offset % step; - - for i in (step_offset..size).step_by(step) { - let circle_domain_index = coset_index_to_circle_domain_index(i, log_size); - let circle_domain_index_bit_rev = bit_reverse_index(circle_domain_index, log_size); - col.set(circle_domain_index_bit_rev, BaseField::one()); - } - - CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col) -} diff --git a/Stwo_wrapper/crates/prover/src/constraint_framework/info.rs b/Stwo_wrapper/crates/prover/src/constraint_framework/info.rs deleted file mode 100644 index 05da93f..0000000 --- a/Stwo_wrapper/crates/prover/src/constraint_framework/info.rs +++ /dev/null @@ -1,48 +0,0 @@ -use std::ops::Mul; - -use num_traits::One; - -use super::EvalAtRow; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::pcs::TreeVec; - -/// Collects information about the constraints. -/// This includes mask offsets and columns at each interaction, and the number of constraints. -#[derive(Default)] -pub struct InfoEvaluator { - pub mask_offsets: TreeVec>>, - pub n_constraints: usize, -} -impl InfoEvaluator { - pub fn new() -> Self { - Self::default() - } -} -impl EvalAtRow for InfoEvaluator { - type F = BaseField; - type EF = SecureField; - fn next_interaction_mask( - &mut self, - interaction: usize, - offsets: [isize; N], - ) -> [Self::F; N] { - // Check if requested a mask from a new interaction - if self.mask_offsets.len() <= interaction { - // Extend `mask_offsets` so that `interaction` is the last index. - self.mask_offsets.resize(interaction + 1, vec![]); - } - self.mask_offsets[interaction].push(offsets.into_iter().collect()); - [BaseField::one(); N] - } - fn add_constraint(&mut self, _constraint: G) - where - Self::EF: Mul, - { - self.n_constraints += 1; - } - - fn combine_ef(_values: [Self::F; 4]) -> Self::EF { - SecureField::one() - } -} diff --git a/Stwo_wrapper/crates/prover/src/constraint_framework/logup.rs b/Stwo_wrapper/crates/prover/src/constraint_framework/logup.rs deleted file mode 100644 index 696a7b9..0000000 --- a/Stwo_wrapper/crates/prover/src/constraint_framework/logup.rs +++ /dev/null @@ -1,315 +0,0 @@ -use std::ops::{Mul, Sub}; - -use itertools::Itertools; -use num_traits::{One, Zero}; - -use super::EvalAtRow; -use crate::core::backend::simd::column::SecureColumn; -use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; -use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; -use crate::core::backend::simd::qm31::PackedSecureField; -use crate::core::backend::simd::SimdBackend; -use crate::core::backend::Column; -use crate::core::channel::Channel; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE}; -use crate::core::fields::FieldExpOps; -use crate::core::lookups::utils::Fraction; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; -use crate::core::poly::BitReversedOrder; -use crate::core::ColumnVec; - -/// Evaluates constraints for batched logups. -/// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = claimed_sum. -/// BATCH_SIZE is the number of fractions to batch together. The degree of the resulting constraints -/// will be BATCH_SIZE + 1. -pub struct LogupAtRow { - /// The index of the interaction used for the cumulative sum columns. - pub interaction: usize, - /// Queue of fractions waiting to be batched together. - pub queue: [(E::EF, E::EF); BATCH_SIZE], - /// Number of fractions in the queue. - pub queue_size: usize, - /// A constant to subtract from each row, to make the totall sum of the last column zero. - /// In other words, claimed_sum / 2^log_size. - /// This is used to make the constraint uniform. - pub cumsum_shift: SecureField, - /// The evaluation of the last cumulative sum column. - pub prev_col_cumsum: E::EF, - is_finalized: bool, -} -impl LogupAtRow { - pub fn new(interaction: usize, claimed_sum: SecureField, log_size: u32) -> Self { - Self { - interaction, - queue: [(E::EF::zero(), E::EF::zero()); BATCH_SIZE], - queue_size: 0, - cumsum_shift: claimed_sum / BaseField::from_u32_unchecked(1 << log_size), - prev_col_cumsum: E::EF::zero(), - is_finalized: false, - } - } - pub fn push_lookup( - &mut self, - eval: &mut E, - numerator: E::EF, - values: &[E::F], - lookup_elements: &LookupElements, - ) { - let shifted_value = lookup_elements.combine(values); - self.push_frac(eval, numerator, shifted_value); - } - - pub fn push_frac(&mut self, eval: &mut E, numerator: E::EF, denominator: E::EF) { - if self.queue_size < BATCH_SIZE { - self.queue[self.queue_size] = (numerator, denominator); - self.queue_size += 1; - return; - } - - // Compute sum_i pi/qi over batch, as a fraction, num/denom. - let (num, denom) = self.fold_queue(); - - self.queue[0] = (numerator, denominator); - self.queue_size = 1; - - // Add a constraint that num / denom = diff. - let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0]; - let diff = cur_cumsum - self.prev_col_cumsum; - self.prev_col_cumsum = cur_cumsum; - eval.add_constraint(diff * denom - num); - } - - pub fn add_frac(&mut self, eval: &mut E, fraction: Fraction) { - // Add a constraint that num / denom = diff. - let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0]; - let diff = cur_cumsum - self.prev_col_cumsum; - self.prev_col_cumsum = cur_cumsum; - eval.add_constraint(diff * fraction.denominator - fraction.numerator); - } - - pub fn finalize(mut self, eval: &mut E) { - assert!(!self.is_finalized, "LogupAtRow was already finalized"); - let (num, denom) = self.fold_queue(); - - let [cur_cumsum, prev_row_cumsum] = - eval.next_extension_interaction_mask(self.interaction, [0, -1]); - - let diff = cur_cumsum - prev_row_cumsum - self.prev_col_cumsum; - // Instead of checking diff = num / denom, check diff = num / denom - cumsum_shift. - // This makes (num / denom - cumsum_shift) have sum zero, which makes the constraint - // uniform - apply on all rows. - let fixed_diff = diff + self.cumsum_shift; - - eval.add_constraint(fixed_diff * denom - num); - - self.is_finalized = true; - } - - fn fold_queue(&self) -> (E::EF, E::EF) { - self.queue[0..self.queue_size] - .iter() - .copied() - .fold((E::EF::zero(), E::EF::one()), |(p0, q0), (pi, qi)| { - (p0 * qi + pi * q0, qi * q0) - }) - } -} - -/// Ensures that the LogupAtRow is finalized. -/// LogupAtRow should be finalized exactly once. -impl Drop for LogupAtRow { - fn drop(&mut self) { - assert!(self.is_finalized, "LogupAtRow was not finalized"); - } -} - -/// Interaction elements for the logup protocol. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct LookupElements { - pub z: SecureField, - pub alpha: SecureField, - alpha_powers: [SecureField; N], -} -impl LookupElements { - pub fn draw(channel: &mut impl Channel) -> Self { - let [z, alpha] = channel.draw_felts(2).try_into().unwrap(); - let mut cur = SecureField::one(); - let alpha_powers = std::array::from_fn(|_| { - let res = cur; - cur *= alpha; - res - }); - Self { - z, - alpha, - alpha_powers, - } - } - pub fn combine(&self, values: &[F]) -> EF - where - EF: Copy + Zero + From + From + Mul + Sub, - { - EF::from(values[0]) - + values[1..] - .iter() - .zip(self.alpha_powers.iter()) - .fold(EF::zero(), |acc, (&value, &power)| { - acc + EF::from(power) * value - }) - - EF::from(self.z) - } - // TODO(spapini): Try to remove this. - pub fn dummy() -> Self { - Self { - z: SecureField::one(), - alpha: SecureField::one(), - alpha_powers: [SecureField::one(); N], - } - } -} - -// SIMD backend generator for logup interaction trace. -pub struct LogupTraceGenerator { - log_size: u32, - /// Current allocated interaction columns. - trace: Vec>, - /// Denominator expressions (z + sum_i alpha^i * x_i) being generated for the current lookup. - denom: SecureColumn, - /// Preallocated buffer for the Inverses of the denominators. - denom_inv: SecureColumn, -} -impl LogupTraceGenerator { - pub fn new(log_size: u32) -> Self { - let trace = vec![]; - let denom = SecureColumn::zeros(1 << log_size); - let denom_inv = SecureColumn::zeros(1 << log_size); - Self { - log_size, - trace, - denom, - denom_inv, - } - } - - /// Allocate a new lookup column. - pub fn new_col(&mut self) -> LogupColGenerator<'_> { - let log_size = self.log_size; - LogupColGenerator { - gen: self, - numerator: SecureColumnByCoords::::zeros(1 << log_size), - } - } - - /// Finalize the trace. Returns the trace and the claimed sum of the last column. - pub fn finalize( - mut self, - ) -> ( - ColumnVec>, - SecureField, - ) { - // Compute claimed sum. - let mut last_col_coords = self.trace.pop().unwrap().columns; - let packed_sums: [PackedBaseField; SECURE_EXTENSION_DEGREE] = last_col_coords - .each_ref() - .map(|c| c.data.iter().copied().sum()); - let base_sums = packed_sums.map(|s| s.pointwise_sum()); - let claimed_sum = SecureField::from_m31_array(base_sums); - - // Shift the last column to make the sum zero. - let cumsum_shift = claimed_sum / BaseField::from_u32_unchecked(1 << self.log_size); - last_col_coords.iter_mut().enumerate().for_each(|(i, c)| { - c.data - .iter_mut() - .for_each(|x| *x -= PackedBaseField::broadcast(cumsum_shift.to_m31_array()[i])) - }); - - // Prefix sum the last column. - let coord_prefix_sum = last_col_coords.map(inclusive_prefix_sum); - self.trace.push(SecureColumnByCoords { - columns: coord_prefix_sum, - }); - - let trace = self - .trace - .into_iter() - .flat_map(|eval| { - eval.columns.map(|c| { - CircleEvaluation::::new( - CanonicCoset::new(self.log_size).circle_domain(), - c, - ) - }) - }) - .collect_vec(); - (trace, claimed_sum) - } -} - -/// Trace generator for a single lookup column. -pub struct LogupColGenerator<'a> { - gen: &'a mut LogupTraceGenerator, - /// Numerator expressions (i.e. multiplicities) being generated for the current lookup. - numerator: SecureColumnByCoords, -} -impl<'a> LogupColGenerator<'a> { - /// Write a fraction to the column at a row. - pub fn write_frac( - &mut self, - vec_row: usize, - numerator: PackedSecureField, - denom: PackedSecureField, - ) { - debug_assert!( - denom.to_array().iter().all(|x| *x != SecureField::zero()), - "{:?}", - ("denom at vec_row {} is zero {}", denom, vec_row) - ); - unsafe { - self.numerator.set_packed(vec_row, numerator); - *self.gen.denom.data.get_unchecked_mut(vec_row) = denom; - } - } - - /// Finalizes generating the column. - pub fn finalize_col(mut self) { - FieldExpOps::batch_inverse(&self.gen.denom.data, &mut self.gen.denom_inv.data); - - for vec_row in 0..(1 << (self.gen.log_size - LOG_N_LANES)) { - unsafe { - let value = self.numerator.packed_at(vec_row) - * *self.gen.denom_inv.data.get_unchecked(vec_row); - let prev_value = self - .gen - .trace - .last() - .map(|col| col.packed_at(vec_row)) - .unwrap_or_else(PackedSecureField::zero); - self.numerator.set_packed(vec_row, value + prev_value) - }; - } - - self.gen.trace.push(self.numerator) - } -} - -#[cfg(test)] -mod tests { - use num_traits::One; - - use super::LogupAtRow; - use crate::constraint_framework::InfoEvaluator; - use crate::core::fields::qm31::SecureField; - - #[test] - #[should_panic] - fn test_logup_not_finalized_panic() { - let mut logup = LogupAtRow::<2, InfoEvaluator>::new(1, SecureField::one(), 7); - logup.push_frac( - &mut InfoEvaluator::default(), - SecureField::one(), - SecureField::one(), - ); - } -} diff --git a/Stwo_wrapper/crates/prover/src/constraint_framework/mod.rs b/Stwo_wrapper/crates/prover/src/constraint_framework/mod.rs deleted file mode 100644 index 87069d3..0000000 --- a/Stwo_wrapper/crates/prover/src/constraint_framework/mod.rs +++ /dev/null @@ -1,97 +0,0 @@ -/// ! This module contains helpers to express and use constraints for components. -mod assert; -mod component; -pub mod constant_columns; -mod info; -pub mod logup; -mod point; -mod simd_domain; - -use std::array; -use std::fmt::Debug; -use std::ops::{Add, AddAssign, Mul, Neg, Sub}; - -pub use assert::{assert_constraints, AssertEvaluator}; -pub use component::{FrameworkComponent, FrameworkEval, TraceLocationAllocator}; -pub use info::InfoEvaluator; -use num_traits::{One, Zero}; -pub use point::PointEvaluator; -pub use simd_domain::SimdDomainEvaluator; - -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; -use crate::core::fields::FieldExpOps; - -/// A trait for evaluating expressions at some point or row. -pub trait EvalAtRow { - // TODO(spapini): Use a better trait for these, like 'Algebra' or something. - /// The field type holding values of columns for the component. These are the inputs to the - /// constraints. It might be [BaseField] packed types, or even [SecureField], when evaluating - /// the columns out of domain. - type F: FieldExpOps - + Copy - + Debug - + Zero - + Neg - + AddAssign - + AddAssign - + Add - + Sub - + Mul - + Add - + Mul - + Neg - + From; - - /// A field type representing the closure of `F` with multiplying by [SecureField]. Constraints - /// usually get multiplied by [SecureField] values for security. - type EF: One - + Copy - + Debug - + Zero - + From - + Neg - + AddAssign - + Add - + Sub - + Mul - + Add - + Mul - + Sub - + Mul - + From - + From; - - /// Returns the next mask value for the first interaction at offset 0. - fn next_trace_mask(&mut self) -> Self::F { - let [mask_item] = self.next_interaction_mask(0, [0]); - mask_item - } - - /// Returns the mask values of the given offsets for the next column in the interaction. - fn next_interaction_mask( - &mut self, - interaction: usize, - offsets: [isize; N], - ) -> [Self::F; N]; - - /// Returns the extension mask values of the given offsets for the next extension degree many - /// columns in the interaction. - fn next_extension_interaction_mask( - &mut self, - interaction: usize, - offsets: [isize; N], - ) -> [Self::EF; N] { - let res_col_major = array::from_fn(|_| self.next_interaction_mask(interaction, offsets)); - array::from_fn(|i| Self::combine_ef(res_col_major.map(|c| c[i]))) - } - - /// Adds a constraint to the component. - fn add_constraint(&mut self, constraint: G) - where - Self::EF: Mul; - - /// Combines 4 base field values into a single extension field value. - fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF; -} diff --git a/Stwo_wrapper/crates/prover/src/constraint_framework/point.rs b/Stwo_wrapper/crates/prover/src/constraint_framework/point.rs deleted file mode 100644 index 6c6f72f..0000000 --- a/Stwo_wrapper/crates/prover/src/constraint_framework/point.rs +++ /dev/null @@ -1,57 +0,0 @@ -use std::ops::Mul; - -use super::EvalAtRow; -use crate::core::air::accumulation::PointEvaluationAccumulator; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; -use crate::core::pcs::TreeVec; -use crate::core::ColumnVec; - -/// Evaluates expressions at a point out of domain. -pub struct PointEvaluator<'a> { - pub mask: TreeVec>>, - pub evaluation_accumulator: &'a mut PointEvaluationAccumulator, - pub col_index: Vec, - pub denom_inverse: SecureField, -} -impl<'a> PointEvaluator<'a> { - pub fn new( - mask: TreeVec>>, - evaluation_accumulator: &'a mut PointEvaluationAccumulator, - denom_inverse: SecureField, - ) -> Self { - let col_index = vec![0; mask.len()]; - Self { - mask, - evaluation_accumulator, - col_index, - denom_inverse, - } - } -} -impl<'a> EvalAtRow for PointEvaluator<'a> { - type F = SecureField; - type EF = SecureField; - - fn next_interaction_mask( - &mut self, - interaction: usize, - _offsets: [isize; N], - ) -> [Self::F; N] { - let col_index = self.col_index[interaction]; - self.col_index[interaction] += 1; - let mask = self.mask[interaction][col_index].clone(); - assert_eq!(mask.len(), N); - mask.try_into().unwrap() - } - fn add_constraint(&mut self, constraint: G) - where - Self::EF: Mul, - { - self.evaluation_accumulator - .accumulate(self.denom_inverse * constraint); - } - fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { - SecureField::from_partial_evals(values) - } -} diff --git a/Stwo_wrapper/crates/prover/src/constraint_framework/simd_domain.rs b/Stwo_wrapper/crates/prover/src/constraint_framework/simd_domain.rs deleted file mode 100644 index ef3662a..0000000 --- a/Stwo_wrapper/crates/prover/src/constraint_framework/simd_domain.rs +++ /dev/null @@ -1,106 +0,0 @@ -use std::ops::Mul; - -use num_traits::Zero; - -use super::EvalAtRow; -use crate::core::backend::simd::column::VeryPackedBaseColumn; -use crate::core::backend::simd::m31::LOG_N_LANES; -use crate::core::backend::simd::very_packed_m31::{ - VeryPackedBaseField, VeryPackedSecureField, LOG_N_VERY_PACKED_ELEMS, -}; -use crate::core::backend::simd::SimdBackend; -use crate::core::backend::Column; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; -use crate::core::pcs::TreeVec; -use crate::core::poly::circle::CircleEvaluation; -use crate::core::poly::BitReversedOrder; -use crate::core::utils::offset_bit_reversed_circle_domain_index; - -/// Evaluates constraints at an evaluation domain points. -pub struct SimdDomainEvaluator<'a> { - pub trace_eval: - &'a TreeVec>>, - pub column_index_per_interaction: Vec, - /// The row index of the simd-vector row to evaluate the constraints at. - pub vec_row: usize, - pub random_coeff_powers: &'a [SecureField], - pub row_res: VeryPackedSecureField, - pub constraint_index: usize, - pub domain_log_size: u32, - pub eval_domain_log_size: u32, -} -impl<'a> SimdDomainEvaluator<'a> { - pub fn new( - trace_eval: &'a TreeVec>>, - vec_row: usize, - random_coeff_powers: &'a [SecureField], - domain_log_size: u32, - eval_log_size: u32, - ) -> Self { - Self { - trace_eval, - column_index_per_interaction: vec![0; trace_eval.len()], - vec_row, - random_coeff_powers, - row_res: VeryPackedSecureField::zero(), - constraint_index: 0, - domain_log_size, - eval_domain_log_size: eval_log_size, - } - } -} -impl<'a> EvalAtRow for SimdDomainEvaluator<'a> { - type F = VeryPackedBaseField; - type EF = VeryPackedSecureField; - - // TODO(spapini): Remove all boundary checks. - fn next_interaction_mask( - &mut self, - interaction: usize, - offsets: [isize; N], - ) -> [Self::F; N] { - let col_index = self.column_index_per_interaction[interaction]; - self.column_index_per_interaction[interaction] += 1; - offsets.map(|off| { - // If the offset is 0, we can just return the value directly from this row. - if off == 0 { - unsafe { - let col = &self - .trace_eval - .get_unchecked(interaction) - .get_unchecked(col_index) - .values; - let very_packed_col = VeryPackedBaseColumn::transform_under_ref(col); - return *very_packed_col.data.get_unchecked(self.vec_row); - }; - } - // Otherwise, we need to look up the value at the offset. - // Since the domain is bit-reversed circle domain ordered, we need to look up the value - // at the bit-reversed natural order index at an offset. - VeryPackedBaseField::from_array(std::array::from_fn(|i| { - let row_index = offset_bit_reversed_circle_domain_index( - (self.vec_row << (LOG_N_LANES + LOG_N_VERY_PACKED_ELEMS)) + i, - self.domain_log_size, - self.eval_domain_log_size, - off, - ); - self.trace_eval[interaction][col_index].at(row_index) - })) - }) - } - fn add_constraint(&mut self, constraint: G) - where - Self::EF: Mul, - { - self.row_res += - VeryPackedSecureField::broadcast(self.random_coeff_powers[self.constraint_index]) - * constraint; - self.constraint_index += 1; - } - - fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { - VeryPackedSecureField::from_very_packed_m31s(values) - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/air/accumulation.rs b/Stwo_wrapper/crates/prover/src/core/air/accumulation.rs deleted file mode 100644 index 8fcf575..0000000 --- a/Stwo_wrapper/crates/prover/src/core/air/accumulation.rs +++ /dev/null @@ -1,297 +0,0 @@ -//! Accumulators for a random linear combination of circle polynomials. -//! Given N polynomials, u_0(P), ... u_{N-1}(P), and a random alpha, the combined polynomial is -//! defined as -//! f(p) = sum_i alpha^{N-1-i} u_i(P). - -use itertools::Itertools; -use tracing::{span, Level}; - -use crate::core::backend::{Backend, Col, Column, CpuBackend}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SecureColumnByCoords; -use crate::core::fields::FieldOps; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, SecureCirclePoly}; -use crate::core::poly::BitReversedOrder; -use crate::core::utils::generate_secure_powers; - -/// Accumulates N evaluations of u_i(P0) at a single point. -/// Computes f(P0), the combined polynomial at that point. -/// For n accumulated evaluations, the i'th evaluation is multiplied by alpha^(N-1-i). -pub struct PointEvaluationAccumulator { - random_coeff: SecureField, - accumulation: SecureField, -} - -impl PointEvaluationAccumulator { - /// Creates a new accumulator. - /// `random_coeff` should be a secure random field element, drawn from the channel. - pub fn new(random_coeff: SecureField) -> Self { - Self { - random_coeff, - accumulation: SecureField::default(), - } - } - - /// Accumulates u_i(P0), a polynomial evaluation at a P0 in reverse order. - pub fn accumulate(&mut self, evaluation: SecureField) { - self.accumulation = self.accumulation * self.random_coeff + evaluation; - } - - pub fn finalize(self) -> SecureField { - self.accumulation - } -} - -// TODO(ShaharS), rename terminology to constraints instead of columns. -/// Accumulates evaluations of u_i(P), each at an evaluation domain of the size of that polynomial. -/// Computes the coefficients of f(P). -pub struct DomainEvaluationAccumulator { - random_coeff_powers: Vec, - /// Accumulated evaluations for each log_size. - /// Each `sub_accumulation` holds the sum over all columns i of that log_size, of - /// `evaluation_i * alpha^(N - 1 - i)` - /// where `N` is the total number of evaluations. - sub_accumulations: Vec>>, -} - -impl DomainEvaluationAccumulator { - /// Creates a new accumulator. - /// `random_coeff` should be a secure random field element, drawn from the channel. - /// `max_log_size` is the maximum log_size of the accumulated evaluations. - pub fn new(random_coeff: SecureField, max_log_size: u32, total_columns: usize) -> Self { - let max_log_size = max_log_size as usize; - Self { - random_coeff_powers: generate_secure_powers(random_coeff, total_columns), - sub_accumulations: (0..(max_log_size + 1)).map(|_| None).collect(), - } - } - - /// Gets accumulators for some sizes. - /// `n_cols_per_size` is an array of pairs (log_size, n_cols). - /// For each entry, a [ColumnAccumulator] is returned, expecting to accumulate `n_cols` - /// evaluations of size `log_size`. - /// The array size, `N`, is the number of different sizes. - pub fn columns( - &mut self, - n_cols_per_size: [(u32, usize); N], - ) -> [ColumnAccumulator<'_, B>; N] { - self.sub_accumulations - .get_many_mut(n_cols_per_size.map(|(log_size, _)| log_size as usize)) - .unwrap_or_else(|e| panic!("invalid log_sizes: {}", e)) - .into_iter() - .zip(n_cols_per_size) - .map(|(col, (log_size, n_cols))| { - let random_coeffs = self - .random_coeff_powers - .split_off(self.random_coeff_powers.len() - n_cols); - ColumnAccumulator { - random_coeff_powers: random_coeffs, - col: col.get_or_insert_with(|| SecureColumnByCoords::zeros(1 << log_size)), - } - }) - .collect_vec() - .try_into() - .unwrap_or_else(|_| unreachable!()) - } - - /// Returns the log size of the resulting polynomial. - pub fn log_size(&self) -> u32 { - (self.sub_accumulations.len() - 1) as u32 - } -} - -pub trait AccumulationOps: FieldOps + Sized { - /// Accumulates other into column: - /// column = column + other. - fn accumulate(column: &mut SecureColumnByCoords, other: &SecureColumnByCoords); -} - -impl DomainEvaluationAccumulator { - /// Computes f(P) as coefficients. - pub fn finalize(self) -> SecureCirclePoly { - assert_eq!( - self.random_coeff_powers.len(), - 0, - "not all random coefficients were used" - ); - let log_size = self.log_size(); - let _span = span!(Level::INFO, "Constraints interpolation").entered(); - let mut cur_poly: Option> = None; - let twiddles = B::precompute_twiddles( - CanonicCoset::new(self.log_size()) - .circle_domain() - .half_coset, - ); - - for (log_size, values) in self.sub_accumulations.into_iter().enumerate().skip(1) { - let Some(mut values) = values else { - continue; - }; - if let Some(prev_poly) = cur_poly { - let eval = SecureColumnByCoords { - columns: prev_poly.0.map(|c| { - c.evaluate_with_twiddles( - CanonicCoset::new(log_size as u32).circle_domain(), - &twiddles, - ) - .values - }), - }; - B::accumulate(&mut values, &eval); - } - cur_poly = Some(SecureCirclePoly(values.columns.map(|c| { - CircleEvaluation::::new( - CanonicCoset::new(log_size as u32).circle_domain(), - c, - ) - .interpolate_with_twiddles(&twiddles) - }))); - } - cur_poly.unwrap_or_else(|| { - SecureCirclePoly(std::array::from_fn(|_| { - CirclePoly::new(Col::::zeros(1 << log_size)) - })) - }) - } -} - -/// A domain accumulator for polynomials of a single size. -pub struct ColumnAccumulator<'a, B: Backend> { - pub random_coeff_powers: Vec, - pub col: &'a mut SecureColumnByCoords, -} -impl<'a> ColumnAccumulator<'a, CpuBackend> { - pub fn accumulate(&mut self, index: usize, evaluation: SecureField) { - let val = self.col.at(index) + evaluation; - self.col.set(index, val); - } -} - -#[cfg(test)] -mod tests { - use std::array; - - use num_traits::Zero; - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use super::*; - use crate::core::backend::cpu::CpuCircleEvaluation; - use crate::core::circle::CirclePoint; - use crate::core::fields::m31::{M31, P}; - use crate::qm31; - - #[test] - fn test_point_evaluation_accumulator() { - // Generate a vector of random sizes with a constant seed. - let mut rng = SmallRng::seed_from_u64(0); - const MAX_LOG_SIZE: u32 = 10; - const MASK: u32 = P; - let log_sizes = (0..100) - .map(|_| rng.gen_range(4..MAX_LOG_SIZE)) - .collect::>(); - - // Generate random evaluations. - let evaluations = log_sizes - .iter() - .map(|_| M31::from_u32_unchecked(rng.gen::() & MASK)) - .collect::>(); - let alpha = qm31!(2, 3, 4, 5); - - // Use accumulator. - let mut accumulator = PointEvaluationAccumulator::new(alpha); - for (_, evaluation) in log_sizes.iter().zip(evaluations.iter()) { - accumulator.accumulate((*evaluation).into()); - } - let accumulator_res = accumulator.finalize(); - - // Use direct computation. - let mut res = SecureField::default(); - for evaluation in evaluations.iter() { - res = res * alpha + *evaluation; - } - - assert_eq!(accumulator_res, res); - } - - #[test] - fn test_domain_evaluation_accumulator() { - // Generate a vector of random sizes with a constant seed. - let mut rng = SmallRng::seed_from_u64(0); - const LOG_SIZE_MIN: u32 = 4; - const LOG_SIZE_BOUND: u32 = 10; - const MASK: u32 = P; - let mut log_sizes = (0..100) - .map(|_| rng.gen_range(LOG_SIZE_MIN..LOG_SIZE_BOUND)) - .collect::>(); - log_sizes.sort(); - - // Generate random evaluations. - let evaluations = log_sizes - .iter() - .map(|log_size| { - (0..(1 << *log_size)) - .map(|_| M31::from_u32_unchecked(rng.gen::() & MASK)) - .collect::>() - }) - .collect::>(); - let alpha = qm31!(2, 3, 4, 5); - - // Use accumulator. - let mut accumulator = DomainEvaluationAccumulator::::new( - alpha, - LOG_SIZE_BOUND, - evaluations.len(), - ); - let n_cols_per_size: [(u32, usize); (LOG_SIZE_BOUND - LOG_SIZE_MIN) as usize] = - array::from_fn(|i| { - let current_log_size = LOG_SIZE_MIN + i as u32; - let n_cols = log_sizes - .iter() - .copied() - .filter(|&log_size| log_size == current_log_size) - .count(); - (current_log_size, n_cols) - }); - let mut cols = accumulator.columns(n_cols_per_size); - let mut eval_chunk_offset = 0; - for (log_size, n_cols) in n_cols_per_size.iter() { - for index in 0..(1 << log_size) { - let mut val = SecureField::zero(); - for (eval_index, (col_log_size, evaluation)) in - log_sizes.iter().zip(evaluations.iter()).enumerate() - { - if *log_size != *col_log_size { - continue; - } - - // The random coefficient powers chunk is in regular order. - let random_coeff_chunk = - &cols[(log_size - LOG_SIZE_MIN) as usize].random_coeff_powers; - val += random_coeff_chunk - [random_coeff_chunk.len() - 1 - (eval_index - eval_chunk_offset)] - * evaluation[index]; - } - cols[(log_size - LOG_SIZE_MIN) as usize].accumulate(index, val); - } - eval_chunk_offset += n_cols; - } - let accumulator_poly = accumulator.finalize(); - - // Pick an arbitrary sample point. - let point = CirclePoint::::get_point(98989892); - let accumulator_res = accumulator_poly.eval_at_point(point); - - // Use direct computation. - let mut res = SecureField::default(); - for (log_size, values) in log_sizes.into_iter().zip(evaluations) { - res = res * alpha - + CpuCircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), values) - .interpolate() - .eval_at_point(point); - } - - assert_eq!(accumulator_res, res); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/air/components.rs b/Stwo_wrapper/crates/prover/src/core/air/components.rs deleted file mode 100644 index a7e0129..0000000 --- a/Stwo_wrapper/crates/prover/src/core/air/components.rs +++ /dev/null @@ -1,80 +0,0 @@ -use itertools::Itertools; - -use super::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; -use super::{Component, ComponentProver, Trace}; -use crate::core::backend::Backend; -use crate::core::circle::CirclePoint; -use crate::core::fields::qm31::SecureField; -use crate::core::pcs::TreeVec; -use crate::core::poly::circle::SecureCirclePoly; -use crate::core::ColumnVec; - -pub struct Components<'a>(pub Vec<&'a dyn Component>); - -impl<'a> Components<'a> { - pub fn composition_log_degree_bound(&self) -> u32 { - self.0 - .iter() - .map(|component| component.max_constraint_log_degree_bound()) - .max() - .unwrap() - } - - pub fn mask_points( - &self, - point: CirclePoint, - ) -> TreeVec>>> { - TreeVec::concat_cols(self.0.iter().map(|component| component.mask_points(point))) - } - - pub fn eval_composition_polynomial_at_point( - &self, - point: CirclePoint, - mask_values: &TreeVec>>, - random_coeff: SecureField, - ) -> SecureField { - //accumulator for the random linear comination over powers of random_coeff - let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff); - - for component in &self.0 { - component.evaluate_constraint_quotients_at_point( - point, - mask_values, - &mut evaluation_accumulator, - ) - } - evaluation_accumulator.finalize() - } - - pub fn column_log_sizes(&self) -> TreeVec> { - TreeVec::concat_cols( - self.0 - .iter() - .map(|component| component.trace_log_degree_bounds()), - ) - } -} - -pub struct ComponentProvers<'a, B: Backend>(pub Vec<&'a dyn ComponentProver>); - -impl<'a, B: Backend> ComponentProvers<'a, B> { - pub fn components(&self) -> Components<'_> { - Components(self.0.iter().map(|c| *c as &dyn Component).collect_vec()) - } - pub fn compute_composition_polynomial( - &self, - random_coeff: SecureField, - trace: &Trace<'_, B>, - ) -> SecureCirclePoly { - let total_constraints: usize = self.0.iter().map(|c| c.n_constraints()).sum(); - let mut accumulator = DomainEvaluationAccumulator::new( - random_coeff, - self.components().composition_log_degree_bound(), - total_constraints, - ); - for component in &self.0 { - component.evaluate_constraint_quotients_on_domain(trace, &mut accumulator) - } - accumulator.finalize() - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/air/mask.rs b/Stwo_wrapper/crates/prover/src/core/air/mask.rs deleted file mode 100644 index e2748a6..0000000 --- a/Stwo_wrapper/crates/prover/src/core/air/mask.rs +++ /dev/null @@ -1,91 +0,0 @@ -use std::collections::HashSet; -use std::vec; - -use itertools::Itertools; - -use crate::core::circle::CirclePoint; -use crate::core::fields::qm31::SecureField; -use crate::core::poly::circle::CanonicCoset; -use crate::core::ColumnVec; - -/// Mask holds a vector with an entry for each column. -/// Each entry holds a list of mask items, which are the offsets of the mask at that column. -type Mask = ColumnVec>; - -/// Returns the same point for each mask item. -/// Should be used where all the mask items has no shift from the constraint point. -pub fn fixed_mask_points( - mask: &Mask, - point: CirclePoint, -) -> ColumnVec>> { - assert_eq!( - mask.iter() - .flat_map(|mask_entry| mask_entry.iter().collect::>()) - .collect::>() - .into_iter() - .collect_vec(), - vec![&0] - ); - mask.iter() - .map(|mask_entry| mask_entry.iter().map(|_| point).collect()) - .collect() -} - -/// For each mask item returns the point shifted by the domain initial point of the column. -/// Should be used where the mask items are shifted from the constraint point. -pub fn shifted_mask_points( - mask: &Mask, - domains: &[CanonicCoset], - point: CirclePoint, -) -> ColumnVec>> { - mask.iter() - .zip(domains.iter()) - .map(|(mask_entry, domain)| { - mask_entry - .iter() - .map(|mask_item| point + domain.at(*mask_item).into_ef()) - .collect() - }) - .collect() -} - -#[cfg(test)] -mod tests { - use crate::core::air::mask::{fixed_mask_points, shifted_mask_points}; - use crate::core::circle::CirclePoint; - use crate::core::poly::circle::CanonicCoset; - - #[test] - fn test_mask_fixed_points() { - let mask = vec![vec![0], vec![0]]; - let constraint_point = CirclePoint::get_point(1234); - - let points = fixed_mask_points(&mask, constraint_point); - - assert_eq!(points.len(), 2); - assert_eq!(points[0].len(), 1); - assert_eq!(points[1].len(), 1); - assert_eq!(points[0][0], constraint_point); - assert_eq!(points[1][0], constraint_point); - } - - #[test] - fn test_mask_shifted_points() { - let mask = vec![vec![0, 1], vec![0, 1, 2]]; - let constraint_point = CirclePoint::get_point(1234); - let domains = (0..mask.len() as u32) - .map(|i| CanonicCoset::new(7 + i)) - .collect::>(); - - let points = shifted_mask_points(&mask, &domains, constraint_point); - - assert_eq!(points.len(), 2); - assert_eq!(points[0].len(), 2); - assert_eq!(points[1].len(), 3); - assert_eq!(points[0][0], constraint_point + domains[0].at(0).into_ef()); - assert_eq!(points[0][1], constraint_point + domains[0].at(1).into_ef()); - assert_eq!(points[1][0], constraint_point + domains[1].at(0).into_ef()); - assert_eq!(points[1][1], constraint_point + domains[1].at(1).into_ef()); - assert_eq!(points[1][2], constraint_point + domains[1].at(2).into_ef()); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/air/mod.rs b/Stwo_wrapper/crates/prover/src/core/air/mod.rs deleted file mode 100644 index fcdd4d5..0000000 --- a/Stwo_wrapper/crates/prover/src/core/air/mod.rs +++ /dev/null @@ -1,76 +0,0 @@ -pub use components::{ComponentProvers, Components}; - -use self::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; -use super::backend::Backend; -use super::circle::CirclePoint; -use super::fields::m31::BaseField; -use super::fields::qm31::SecureField; -use super::pcs::TreeVec; -use super::poly::circle::{CircleEvaluation, CirclePoly}; -use super::poly::BitReversedOrder; -use super::ColumnVec; - -pub mod accumulation; -mod components; -pub mod mask; - -/// Arithmetic Intermediate Representation (AIR). -/// An Air instance is assumed to already contain all the information needed to -/// evaluate the constraints. -/// For instance, all interaction elements are assumed to be present in it. -/// Therefore, an AIR is generated only after the initial trace commitment phase. -// TODO(spapini): consider renaming this struct. -pub trait Air { - fn components(&self) -> Vec<&dyn Component>; -} - -pub trait AirProver: Air { - fn component_provers(&self) -> Vec<&dyn ComponentProver>; -} - -/// A component is a set of trace columns of various sizes along with a set of -/// constraints on them. -pub trait Component { - fn n_constraints(&self) -> usize; - - fn max_constraint_log_degree_bound(&self) -> u32; - - /// Returns the degree bounds of each trace column. The returned TreeVec should be of size - /// `n_interaction_phases`. - fn trace_log_degree_bounds(&self) -> TreeVec>; - - /// Returns the mask points for each trace column. The returned TreeVec should be of size - /// `n_interaction_phases`. - fn mask_points( - &self, - point: CirclePoint, - ) -> TreeVec>>>; - - /// Evaluates the constraint quotients combination of the component at a point. - fn evaluate_constraint_quotients_at_point( - &self, - point: CirclePoint, - mask: &TreeVec>>, - evaluation_accumulator: &mut PointEvaluationAccumulator, - ); -} - -pub trait ComponentProver: Component { - /// Evaluates the constraint quotients of the component on the evaluation domain. - /// Accumulates quotients in `evaluation_accumulator`. - fn evaluate_constraint_quotients_on_domain( - &self, - trace: &Trace<'_, B>, - evaluation_accumulator: &mut DomainEvaluationAccumulator, - ); -} - -/// The set of polynomials that make up the trace. -/// -/// Each polynomial is stored both in a coefficients, and evaluations form (for efficiency) -pub struct Trace<'a, B: Backend> { - /// Polynomials for each column. - pub polys: TreeVec>>, - /// Evaluations for each column (evaluated on their commitment domains). - pub evals: TreeVec>>, -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/accumulation.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/accumulation.rs deleted file mode 100644 index 63a49bf..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/cpu/accumulation.rs +++ /dev/null @@ -1,12 +0,0 @@ -use super::CpuBackend; -use crate::core::air::accumulation::AccumulationOps; -use crate::core::fields::secure_column::SecureColumnByCoords; - -impl AccumulationOps for CpuBackend { - fn accumulate(column: &mut SecureColumnByCoords, other: &SecureColumnByCoords) { - for i in 0..column.len() { - let res_coeff = column.at(i) + other.at(i); - column.set(i, res_coeff); - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/blake2s.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/blake2s.rs deleted file mode 100644 index a87a5ae..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/cpu/blake2s.rs +++ /dev/null @@ -1,24 +0,0 @@ -use itertools::Itertools; - -use crate::core::backend::CpuBackend; -use crate::core::fields::m31::BaseField; -use crate::core::vcs::blake2_hash::Blake2sHash; -use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; -use crate::core::vcs::ops::{MerkleHasher, MerkleOps}; - -impl MerkleOps for CpuBackend { - fn commit_on_layer( - log_size: u32, - prev_layer: Option<&Vec>, - columns: &[&Vec], - ) -> Vec { - (0..(1 << log_size)) - .map(|i| { - Blake2sMerkleHasher::hash_node( - prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])), - &columns.iter().map(|column| column[i]).collect_vec(), - ) - }) - .collect() - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/circle.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/circle.rs deleted file mode 100644 index c37ffe2..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/cpu/circle.rs +++ /dev/null @@ -1,376 +0,0 @@ -use num_traits::Zero; - -use super::CpuBackend; -use crate::core::backend::{Col, ColumnOps}; -use crate::core::circle::{CirclePoint, Coset}; -use crate::core::fft::{butterfly, ibutterfly}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::{ExtensionOf, FieldExpOps}; -use crate::core::poly::circle::{ - CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps, -}; -use crate::core::poly::twiddles::TwiddleTree; -use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold}; -use crate::core::poly::BitReversedOrder; -use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; - -impl PolyOps for CpuBackend { - type Twiddles = Vec; - - fn new_canonical_ordered( - coset: CanonicCoset, - values: Col, - ) -> CircleEvaluation { - let domain = coset.circle_domain(); - assert_eq!(values.len(), domain.size()); - let mut new_values = coset_order_to_circle_domain_order(&values); - CpuBackend::bit_reverse_column(&mut new_values); - CircleEvaluation::new(domain, new_values) - } - - fn interpolate( - eval: CircleEvaluation, - twiddles: &TwiddleTree, - ) -> CirclePoly { - assert!(eval.domain.half_coset.is_doubling_of(twiddles.root_coset)); - - let mut values = eval.values; - - if eval.domain.log_size() == 1 { - let y = eval.domain.half_coset.initial.y; - let n = BaseField::from(2); - let yn_inv = (y * n).inverse(); - let y_inv = yn_inv * n; - let n_inv = yn_inv * y; - let (mut v0, mut v1) = (values[0], values[1]); - ibutterfly(&mut v0, &mut v1, y_inv); - return CirclePoly::new(vec![v0 * n_inv, v1 * n_inv]); - } - - if eval.domain.log_size() == 2 { - let CirclePoint { x, y } = eval.domain.half_coset.initial; - let n = BaseField::from(4); - let xyn_inv = (x * y * n).inverse(); - let x_inv = xyn_inv * y * n; - let y_inv = xyn_inv * x * n; - let n_inv = xyn_inv * x * y; - let (mut v0, mut v1, mut v2, mut v3) = (values[0], values[1], values[2], values[3]); - ibutterfly(&mut v0, &mut v1, y_inv); - ibutterfly(&mut v2, &mut v3, -y_inv); - ibutterfly(&mut v0, &mut v2, x_inv); - ibutterfly(&mut v1, &mut v3, x_inv); - return CirclePoly::new(vec![v0 * n_inv, v1 * n_inv, v2 * n_inv, v3 * n_inv]); - } - - let line_twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles); - let circle_twiddles = circle_twiddles_from_line_twiddles(line_twiddles[0]); - - for (h, t) in circle_twiddles.enumerate() { - fft_layer_loop(&mut values, 0, h, t, ibutterfly); - } - for (layer, layer_twiddles) in line_twiddles.into_iter().enumerate() { - for (h, &t) in layer_twiddles.iter().enumerate() { - fft_layer_loop(&mut values, layer + 1, h, t, ibutterfly); - } - } - - // Divide all values by 2^log_size. - let inv = BaseField::from_u32_unchecked(eval.domain.size() as u32).inverse(); - for val in &mut values { - *val *= inv; - } - - CirclePoly::new(values) - } - - fn eval_at_point(poly: &CirclePoly, point: CirclePoint) -> SecureField { - if poly.log_size() == 0 { - return poly.coeffs[0].into(); - } - - let mut mappings = vec![point.y]; - let mut x = point.x; - for _ in 1..poly.log_size() { - mappings.push(x); - x = CirclePoint::double_x(x); - } - mappings.reverse(); - - fold(&poly.coeffs, &mappings) - } - - fn extend(poly: &CirclePoly, log_size: u32) -> CirclePoly { - assert!(log_size >= poly.log_size()); - let mut coeffs = Vec::with_capacity(1 << log_size); - coeffs.extend_from_slice(&poly.coeffs); - coeffs.resize(1 << log_size, BaseField::zero()); - CirclePoly::new(coeffs) - } - - fn evaluate( - poly: &CirclePoly, - domain: CircleDomain, - twiddles: &TwiddleTree, - ) -> CircleEvaluation { - assert!(domain.half_coset.is_doubling_of(twiddles.root_coset)); - - let mut values = poly.extend(domain.log_size()).coeffs; - - if domain.log_size() == 1 { - let (mut v0, mut v1) = (values[0], values[1]); - butterfly(&mut v0, &mut v1, domain.half_coset.initial.y); - return CircleEvaluation::new(domain, vec![v0, v1]); - } - - if domain.log_size() == 2 { - let (mut v0, mut v1, mut v2, mut v3) = (values[0], values[1], values[2], values[3]); - let CirclePoint { x, y } = domain.half_coset.initial; - butterfly(&mut v0, &mut v2, x); - butterfly(&mut v1, &mut v3, x); - butterfly(&mut v0, &mut v1, y); - butterfly(&mut v2, &mut v3, -y); - return CircleEvaluation::new(domain, vec![v0, v1, v2, v3]); - } - - let line_twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles); - let circle_twiddles = circle_twiddles_from_line_twiddles(line_twiddles[0]); - - for (layer, layer_twiddles) in line_twiddles.iter().enumerate().rev() { - for (h, &t) in layer_twiddles.iter().enumerate() { - fft_layer_loop(&mut values, layer + 1, h, t, butterfly); - } - } - for (h, t) in circle_twiddles.enumerate() { - fft_layer_loop(&mut values, 0, h, t, butterfly); - } - - CircleEvaluation::new(domain, values) - } - - fn precompute_twiddles(mut coset: Coset) -> TwiddleTree { - const CHUNK_LOG_SIZE: usize = 12; - const CHUNK_SIZE: usize = 1 << CHUNK_LOG_SIZE; - - let root_coset = coset; - let mut twiddles = Vec::with_capacity(coset.size()); - for _ in 0..coset.log_size() { - let i0 = twiddles.len(); - twiddles.extend( - coset - .iter() - .take(coset.size() / 2) - .map(|p| p.x) - .collect::>(), - ); - bit_reverse(&mut twiddles[i0..]); - coset = coset.double(); - } - twiddles.push(1.into()); - - // Inverse twiddles. - // Fallback to the non-chunked version if the domain is not big enough. - if CHUNK_SIZE > coset.size() { - let itwiddles = twiddles.iter().map(|&t| t.inverse()).collect(); - return TwiddleTree { - root_coset, - twiddles, - itwiddles, - }; - } - - let mut itwiddles = vec![BaseField::zero(); twiddles.len()]; - twiddles - .array_chunks::() - .zip(itwiddles.array_chunks_mut::()) - .for_each(|(src, dst)| { - BaseField::batch_inverse(src, dst); - }); - - TwiddleTree { - root_coset, - twiddles, - itwiddles, - } - } -} - -fn fft_layer_loop( - values: &mut [BaseField], - i: usize, - h: usize, - t: BaseField, - butterfly_fn: impl Fn(&mut BaseField, &mut BaseField, BaseField), -) { - for l in 0..(1 << i) { - let idx0 = (h << (i + 1)) + l; - let idx1 = idx0 + (1 << i); - let (mut val0, mut val1) = (values[idx0], values[idx1]); - butterfly_fn(&mut val0, &mut val1, t); - (values[idx0], values[idx1]) = (val0, val1); - } -} - -/// Computes the circle twiddles layer (layer 0) from the first line twiddles layer (layer 1). -/// -/// Only works for line twiddles generated from a domain with size `>4`. -fn circle_twiddles_from_line_twiddles( - first_line_twiddles: &[BaseField], -) -> impl Iterator + '_ { - // The twiddles for layer 0 can be computed from the twiddles for layer 1. - // Since the twiddles are bit reversed, we consider the circle domain in bit reversed order. - // Each consecutive 4 points in the bit reversed order of a coset form a circle coset of size 4. - // A circle coset of size 4 in bit reversed order looks like this: - // [(x, y), (-x, -y), (y, -x), (-y, x)] - // Note: This relation is derived from the fact that `M31_CIRCLE_GEN`.repeated_double(ORDER / 4) - // == (-1,0), and not (0,1). (0,1) would yield another relation. - // The twiddles for layer 0 are the y coordinates: - // [y, -y, -x, x] - // The twiddles for layer 1 in bit reversed order are the x coordinates of the even indices - // points: - // [x, y] - // Works also for inverse of the twiddles. - first_line_twiddles - .iter() - .array_chunks() - .flat_map(|[&x, &y]| [y, -y, -x, x]) -} - -impl, EvalOrder> IntoIterator - for CircleEvaluation -{ - type Item = F; - type IntoIter = std::vec::IntoIter; - - /// Creates a consuming iterator over the evaluations. - /// - /// Evaluations are returned in the same order as elements of the domain. - fn into_iter(self) -> Self::IntoIter { - self.values.into_iter() - } -} - -#[cfg(test)] -mod tests { - use std::iter::zip; - - use num_traits::One; - - use crate::core::backend::cpu::CpuCirclePoly; - use crate::core::circle::CirclePoint; - use crate::core::fields::m31::BaseField; - use crate::core::fields::qm31::SecureField; - use crate::core::poly::circle::CanonicCoset; - - #[test] - fn test_eval_at_point_with_4_coeffs() { - // Represents the polynomial `1 + 2y + 3x + 4xy`. - // Note coefficients are passed in bit reversed order. - let poly = CpuCirclePoly::new([1, 3, 2, 4].map(BaseField::from).to_vec()); - let x = BaseField::from(5).into(); - let y = BaseField::from(8).into(); - - let eval = poly.eval_at_point(CirclePoint { x, y }); - - assert_eq!( - eval, - poly.coeffs[0] + poly.coeffs[1] * y + poly.coeffs[2] * x + poly.coeffs[3] * x * y - ); - } - - #[test] - fn test_eval_at_point_with_2_coeffs() { - // Represents the polynomial `1 + 2y`. - let poly = CpuCirclePoly::new(vec![BaseField::from(1), BaseField::from(2)]); - let x = BaseField::from(5).into(); - let y = BaseField::from(8).into(); - - let eval = poly.eval_at_point(CirclePoint { x, y }); - - assert_eq!(eval, poly.coeffs[0] + poly.coeffs[1] * y); - } - - #[test] - fn test_eval_at_point_with_1_coeff() { - // Represents the polynomial `1`. - let poly = CpuCirclePoly::new(vec![BaseField::one()]); - let x = BaseField::from(5).into(); - let y = BaseField::from(8).into(); - - let eval = poly.eval_at_point(CirclePoint { x, y }); - - assert_eq!(eval, SecureField::one()); - } - - #[test] - fn test_evaluate_2_coeffs() { - let domain = CanonicCoset::new(1).circle_domain(); - let poly = CpuCirclePoly::new((1..=2).map(BaseField::from).collect()); - - let evaluation = poly.clone().evaluate(domain).bit_reverse(); - - for (i, (p, eval)) in zip(domain, evaluation).enumerate() { - let eval: SecureField = eval.into(); - assert_eq!(eval, poly.eval_at_point(p.into_ef()), "mismatch at i={i}"); - } - } - - #[test] - fn test_evaluate_4_coeffs() { - let domain = CanonicCoset::new(2).circle_domain(); - let poly = CpuCirclePoly::new((1..=4).map(BaseField::from).collect()); - - let evaluation = poly.clone().evaluate(domain).bit_reverse(); - - for (i, (x, eval)) in zip(domain, evaluation).enumerate() { - let eval: SecureField = eval.into(); - assert_eq!(eval, poly.eval_at_point(x.into_ef()), "mismatch at i={i}"); - } - } - - #[test] - fn test_evaluate_8_coeffs() { - let domain = CanonicCoset::new(3).circle_domain(); - let poly = CpuCirclePoly::new((1..=8).map(BaseField::from).collect()); - - let evaluation = poly.clone().evaluate(domain).bit_reverse(); - - for (i, (x, eval)) in zip(domain, evaluation).enumerate() { - let eval: SecureField = eval.into(); - assert_eq!(eval, poly.eval_at_point(x.into_ef()), "mismatch at i={i}"); - } - } - - #[test] - fn test_interpolate_2_evals() { - let poly = CpuCirclePoly::new(vec![BaseField::one(), BaseField::from(2)]); - let domain = CanonicCoset::new(1).circle_domain(); - let evals = poly.clone().evaluate(domain); - - let interpolated_poly = evals.interpolate(); - - assert_eq!(interpolated_poly.coeffs, poly.coeffs); - } - - #[test] - fn test_interpolate_4_evals() { - let poly = CpuCirclePoly::new((1..=4).map(BaseField::from).collect()); - let domain = CanonicCoset::new(2).circle_domain(); - let evals = poly.clone().evaluate(domain); - - let interpolated_poly = evals.interpolate(); - - assert_eq!(interpolated_poly.coeffs, poly.coeffs); - } - - #[test] - fn test_interpolate_8_evals() { - let poly = CpuCirclePoly::new((1..=8).map(BaseField::from).collect()); - let domain = CanonicCoset::new(3).circle_domain(); - let evals = poly.clone().evaluate(domain); - - let interpolated_poly = evals.interpolate(); - - assert_eq!(interpolated_poly.coeffs, poly.coeffs); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/fri.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/fri.rs deleted file mode 100644 index 693fb99..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/cpu/fri.rs +++ /dev/null @@ -1,144 +0,0 @@ -use super::CpuBackend; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SecureColumnByCoords; -use crate::core::fri::{fold_circle_into_line, fold_line, FriOps}; -use crate::core::poly::circle::SecureEvaluation; -use crate::core::poly::line::LineEvaluation; -use crate::core::poly::twiddles::TwiddleTree; -use crate::core::poly::BitReversedOrder; - -// TODO(spapini): Optimized these functions as well. -impl FriOps for CpuBackend { - fn fold_line( - eval: &LineEvaluation, - alpha: SecureField, - _twiddles: &TwiddleTree, - ) -> LineEvaluation { - fold_line(eval, alpha) - } - fn fold_circle_into_line( - dst: &mut LineEvaluation, - src: &SecureEvaluation, - alpha: SecureField, - _twiddles: &TwiddleTree, - ) { - fold_circle_into_line(dst, src, alpha) - } - - fn decompose( - eval: &SecureEvaluation, - ) -> (SecureEvaluation, SecureField) { - let lambda = Self::decomposition_coefficient(eval); - let mut g_values = unsafe { SecureColumnByCoords::::uninitialized(eval.len()) }; - - let domain_size = eval.len(); - let half_domain_size = domain_size / 2; - - for i in 0..half_domain_size { - let x = eval.values.at(i); - let val = x - lambda; - g_values.set(i, val); - } - for i in half_domain_size..domain_size { - let x = eval.values.at(i); - let val = x + lambda; - g_values.set(i, val); - } - - let g = SecureEvaluation::new(eval.domain, g_values); - (g, lambda) - } -} - -impl CpuBackend { - /// Used to decompose a general polynomial to a polynomial inside the fft-space, and - /// the remainder terms. - /// A coset-diff on a [`CirclePoly`] that is in the FFT space will return zero. - /// - /// Let N be the domain size, Let h be a coset size N/2. Using lemma #7 from the CircleStark - /// paper, = lambda = lambda\*N => lambda = f(0)\*V_h(0) + f(1)*V_h(1) + .. + - /// f(N-1)\*V_h(N-1). The Vanishing polynomial of a cannonic coset sized half the circle - /// domain,evaluated on the circle domain, is [(1, -1, -1, 1)] repeating. This becomes - /// alternating [+-1] in our NaturalOrder, and [(+, +, +, ... , -, -)] in bit reverse. - /// Explicitly, lambda\*N = sum(+f(0..N/2)) + sum(-f(N/2..)). - /// - /// # Warning - /// This function assumes the blowupfactor is 2 - /// - /// [`CirclePoly`]: crate::core::poly::circle::CirclePoly - fn decomposition_coefficient(eval: &SecureEvaluation) -> SecureField { - let domain_size = 1 << eval.domain.log_size(); - let half_domain_size = domain_size / 2; - - // eval is in bit-reverse, hence all the positive factors are in the first half, opposite to - // the latter. - let a_sum = (0..half_domain_size) - .map(|i| eval.values.at(i)) - .sum::(); - let b_sum = (half_domain_size..domain_size) - .map(|i| eval.values.at(i)) - .sum::(); - - // lambda = sum(+-f(p)) / 2N. - (a_sum - b_sum) / BaseField::from_u32_unchecked(domain_size as u32) - } -} - -#[cfg(test)] -mod tests { - use num_traits::Zero; - - use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly}; - use crate::core::backend::CpuBackend; - use crate::core::fields::m31::BaseField; - use crate::core::fields::qm31::SecureField; - use crate::core::fields::secure_column::SecureColumnByCoords; - use crate::core::fri::FriOps; - use crate::core::poly::circle::{CanonicCoset, SecureEvaluation}; - use crate::core::poly::BitReversedOrder; - use crate::m31; - - #[test] - fn decompose_coeff_out_fft_space_test() { - for domain_log_size in 5..12 { - let domain_log_half_size = domain_log_size - 1; - let s = CanonicCoset::new(domain_log_size); - let domain = s.circle_domain(); - - let mut coeffs = vec![BaseField::zero(); 1 << domain_log_size]; - - // Polynomial is out of FFT space. - coeffs[1 << domain_log_half_size] = m31!(1); - assert!(!CpuCirclePoly::new(coeffs.clone()).is_in_fft_space(domain_log_half_size)); - - let poly = CpuCirclePoly::new(coeffs); - let values = poly.evaluate(domain); - let secure_column = SecureColumnByCoords { - columns: [ - values.values.clone(), - values.values.clone(), - values.values.clone(), - values.values.clone(), - ], - }; - let secure_eval = SecureEvaluation::::new( - domain, - secure_column.clone(), - ); - - let (g, lambda) = CpuBackend::decompose(&secure_eval); - - // Sanity check. - assert_ne!(lambda, SecureField::zero()); - - // Assert the new polynomial is in the FFT space. - for i in 0..4 { - let basefield_column = g.columns[i].clone(); - let eval = CpuCircleEvaluation::new(domain, basefield_column); - let coeffs = eval.interpolate().coeffs; - assert!(CpuCirclePoly::new(coeffs).is_in_fft_space(domain_log_half_size)); - } - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/grind.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/grind.rs deleted file mode 100644 index c5d27a1..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/cpu/grind.rs +++ /dev/null @@ -1,18 +0,0 @@ -use super::CpuBackend; -use crate::core::channel::Channel; -use crate::core::proof_of_work::GrindOps; - -impl GrindOps for CpuBackend { - fn grind(channel: &C, pow_bits: u32) -> u64 { - // TODO(spapini): This is a naive implementation. Optimize it. - let mut nonce = 0; - loop { - let mut channel = channel.clone(); - channel.mix_nonce(nonce); - if channel.trailing_zeros() >= pow_bits { - return nonce; - } - nonce += 1; - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/gkr.rs deleted file mode 100644 index ae9ab6b..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/gkr.rs +++ /dev/null @@ -1,448 +0,0 @@ -use std::ops::Index; - -use num_traits::{One, Zero}; - -use crate::core::backend::CpuBackend; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::{ExtensionOf, Field}; -use crate::core::lookups::gkr_prover::{ - correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer, -}; -use crate::core::lookups::mle::{Mle, MleOps}; -use crate::core::lookups::sumcheck::MultivariatePolyOracle; -use crate::core::lookups::utils::{Fraction, Reciprocal, UnivariatePoly}; - -impl GkrOps for CpuBackend { - fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle { - Mle::new(gen_eq_evals(y, v)) - } - - fn next_layer(layer: &Layer) -> Layer { - match layer { - Layer::GrandProduct(layer) => next_grand_product_layer(layer), - Layer::LogUpGeneric { - numerators, - denominators, - } => next_logup_layer(MleExpr::Mle(numerators), denominators), - Layer::LogUpMultiplicities { - numerators, - denominators, - } => next_logup_layer(MleExpr::Mle(numerators), denominators), - Layer::LogUpSingles { denominators } => { - next_logup_layer(MleExpr::Constant(BaseField::one()), denominators) - } - } - } - - fn sum_as_poly_in_first_variable( - h: &GkrMultivariatePolyOracle<'_, Self>, - claim: SecureField, - ) -> UnivariatePoly { - let n_variables = h.n_variables(); - assert!(!n_variables.is_zero()); - let n_terms = 1 << (n_variables - 1); - let eq_evals = h.eq_evals.as_ref(); - // Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube. - let y = eq_evals.y(); - let lambda = h.lambda; - - let (mut eval_at_0, mut eval_at_2) = match &h.input_layer { - Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_terms), - Layer::LogUpGeneric { - numerators, - denominators, - } => eval_logup_sum(eq_evals, numerators, denominators, n_terms, lambda), - Layer::LogUpMultiplicities { - numerators, - denominators, - } => eval_logup_sum(eq_evals, numerators, denominators, n_terms, lambda), - Layer::LogUpSingles { denominators } => { - eval_logup_singles_sum(eq_evals, denominators, n_terms, lambda) - } - }; - - eval_at_0 *= h.eq_fixed_var_correction; - eval_at_2 *= h.eq_fixed_var_correction; - correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables) - } -} - -/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`. -/// -/// Output of the form: `(eval_at_0, eval_at_2)`. -fn eval_grand_product_sum( - eq_evals: &EqEvals, - input_layer: &Mle, - n_terms: usize, -) -> (SecureField, SecureField) { - let mut eval_at_0 = SecureField::zero(); - let mut eval_at_2 = SecureField::zero(); - - for i in 0..n_terms { - // Input polynomial at points `(r, {0, 1, 2}, bits(i), {0, 1})`. - let inp_at_r0i0 = input_layer[i * 2]; - let inp_at_r0i1 = input_layer[i * 2 + 1]; - let inp_at_r1i0 = input_layer[(n_terms + i) * 2]; - let inp_at_r1i1 = input_layer[(n_terms + i) * 2 + 1]; - // Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)` - // => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)` - // TODO(andrew): Consider evaluation at `1/2` to save an addition operation since - // `inp(r, 1/2, x) = 1/2 * (inp(r, 1, x) + inp(r, 0, x))`. `1/2 * ...` can be factored - // outside the loop. - let inp_at_r2i0 = inp_at_r1i0.double() - inp_at_r0i0; - let inp_at_r2i1 = inp_at_r1i1.double() - inp_at_r0i1; - - // Product polynomial `prod(x) = inp(x, 0) * inp(x, 1)` at points `(r, {0, 2}, bits(i))`. - let prod_at_r2i = inp_at_r2i0 * inp_at_r2i1; - let prod_at_r0i = inp_at_r0i0 * inp_at_r0i1; - - let eq_eval_at_0i = eq_evals[i]; - eval_at_0 += eq_eval_at_0i * prod_at_r0i; - eval_at_2 += eq_eval_at_0i * prod_at_r2i; - } - - (eval_at_0, eval_at_2) -} - -/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_numer(r, t, x, 0) * inp_denom(r, t, x, 1) + -/// inp_numer(r, t, x, 1) * inp_denom(r, t, x, 0) + lambda * inp_denom(r, t, x, 0) * inp_denom(r, t, -/// x, 1))` at `t=0` and `t=2`. -/// -/// Output of the form: `(eval_at_0, eval_at_2)`. -fn eval_logup_sum( - eq_evals: &EqEvals, - input_numerators: &Mle, - input_denominators: &Mle, - n_terms: usize, - lambda: SecureField, -) -> (SecureField, SecureField) -where - SecureField: ExtensionOf + Field, -{ - let mut eval_at_0 = SecureField::zero(); - let mut eval_at_2 = SecureField::zero(); - - for i in 0..n_terms { - // Input polynomials at points `(r, {0, 1, 2}, bits(i), {0, 1})`. - let inp_numer_at_r0i0 = input_numerators[i * 2]; - let inp_denom_at_r0i0 = input_denominators[i * 2]; - let inp_numer_at_r0i1 = input_numerators[i * 2 + 1]; - let inp_denom_at_r0i1 = input_denominators[i * 2 + 1]; - let inp_numer_at_r1i0 = input_numerators[(n_terms + i) * 2]; - let inp_denom_at_r1i0 = input_denominators[(n_terms + i) * 2]; - let inp_numer_at_r1i1 = input_numerators[(n_terms + i) * 2 + 1]; - let inp_denom_at_r1i1 = input_denominators[(n_terms + i) * 2 + 1]; - // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` - // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` - let inp_numer_at_r2i0 = inp_numer_at_r1i0.double() - inp_numer_at_r0i0; - let inp_denom_at_r2i0 = inp_denom_at_r1i0.double() - inp_denom_at_r0i0; - let inp_numer_at_r2i1 = inp_numer_at_r1i1.double() - inp_numer_at_r0i1; - let inp_denom_at_r2i1 = inp_denom_at_r1i1.double() - inp_denom_at_r0i1; - - // Fraction addition polynomials: - // - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)` - // - `denom(x) = inp_denom(x, 1) * inp_denom(x, 0)` - // at points `(r, {0, 2}, bits(i))`. - let Fraction { - numerator: numer_at_r0i, - denominator: denom_at_r0i, - } = Fraction::new(inp_numer_at_r0i0, inp_denom_at_r0i0) - + Fraction::new(inp_numer_at_r0i1, inp_denom_at_r0i1); - let Fraction { - numerator: numer_at_r2i, - denominator: denom_at_r2i, - } = Fraction::new(inp_numer_at_r2i0, inp_denom_at_r2i0) - + Fraction::new(inp_numer_at_r2i1, inp_denom_at_r2i1); - - let eq_eval_at_0i = eq_evals[i]; - eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i); - eval_at_2 += eq_eval_at_0i * (numer_at_r2i + lambda * denom_at_r2i); - } - - (eval_at_0, eval_at_2) -} - -/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_denom(r, t, x, 1) + inp_denom(r, t, x, 0) + -/// lambda * inp_denom(r, t, x, 0) * inp_denom(r, t, x, 1))` at `t=0` and `t=2`. -/// -/// Output of the form: `(eval_at_0, eval_at_2)`. -fn eval_logup_singles_sum( - eq_evals: &EqEvals, - input_denominators: &Mle, - n_terms: usize, - lambda: SecureField, -) -> (SecureField, SecureField) { - let mut eval_at_0 = SecureField::zero(); - let mut eval_at_2 = SecureField::zero(); - - for i in 0..n_terms { - // Input polynomial at points `(r, {0, 1, 2}, bits(i), {0, 1})`. - let inp_denom_at_r0i0 = input_denominators[i * 2]; - let inp_denom_at_r0i1 = input_denominators[i * 2 + 1]; - let inp_denom_at_r1i0 = input_denominators[(n_terms + i) * 2]; - let inp_denom_at_r1i1 = input_denominators[(n_terms + i) * 2 + 1]; - // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` - // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` - let inp_denom_at_r2i0 = inp_denom_at_r1i0.double() - inp_denom_at_r0i0; - let inp_denom_at_r2i1 = inp_denom_at_r1i1.double() - inp_denom_at_r0i1; - - // Fraction addition polynomials at points: - // - `numer(x) = inp_denom(x, 1) + inp_denom(x, 0)` - // - `denom(x) = inp_denom(x, 1) * inp_denom(x, 0)` - // at points `(r, {0, 2}, bits(i))`. - let Fraction { - numerator: numer_at_r0i, - denominator: denom_at_r0i, - } = Reciprocal::new(inp_denom_at_r0i0) + Reciprocal::new(inp_denom_at_r0i1); - let Fraction { - numerator: numer_at_r2i, - denominator: denom_at_r2i, - } = Reciprocal::new(inp_denom_at_r2i0) + Reciprocal::new(inp_denom_at_r2i1); - - let eq_eval_at_0i = eq_evals[i]; - eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i); - eval_at_2 += eq_eval_at_0i * (numer_at_r2i + lambda * denom_at_r2i); - } - - (eval_at_0, eval_at_2) -} - -/// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`. -/// -/// Evaluations are returned in bit-reversed order. -pub fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Vec { - let mut evals = Vec::with_capacity(1 << y.len()); - evals.push(v); - - for &y_i in y.iter().rev() { - for j in 0..evals.len() { - // `lhs[j] = eq(0, y_i) * c[i]` - // `rhs[j] = eq(1, y_i) * c[i]` - let tmp = evals[j] * y_i; - evals.push(tmp); - evals[j] -= tmp; - } - } - - evals -} - -fn next_grand_product_layer(layer: &Mle) -> Layer { - let res = layer.array_chunks().map(|&[a, b]| a * b).collect(); - Layer::GrandProduct(Mle::new(res)) -} - -fn next_logup_layer( - numerators: MleExpr<'_, F>, - denominators: &Mle, -) -> Layer -where - F: Field, - SecureField: ExtensionOf, - CpuBackend: MleOps, -{ - let half_n = 1 << (denominators.n_variables() - 1); - let mut next_numerators = Vec::with_capacity(half_n); - let mut next_denominators = Vec::with_capacity(half_n); - - for i in 0..half_n { - let a = Fraction::new(numerators[i * 2], denominators[i * 2]); - let b = Fraction::new(numerators[i * 2 + 1], denominators[i * 2 + 1]); - let res = a + b; - next_numerators.push(res.numerator); - next_denominators.push(res.denominator); - } - - Layer::LogUpGeneric { - numerators: Mle::new(next_numerators), - denominators: Mle::new(next_denominators), - } -} - -enum MleExpr<'a, F: Field> { - Constant(F), - Mle(&'a Mle), -} - -impl<'a, F: Field> Index for MleExpr<'a, F> { - type Output = F; - - fn index(&self, index: usize) -> &F { - match self { - Self::Constant(v) => v, - Self::Mle(mle) => &mle[index], - } - } -} - -#[cfg(test)] -mod tests { - use std::iter::zip; - - use num_traits::{One, Zero}; - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use crate::core::backend::CpuBackend; - use crate::core::channel::Channel; - use crate::core::fields::m31::BaseField; - use crate::core::fields::qm31::SecureField; - use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer}; - use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError}; - use crate::core::lookups::mle::Mle; - use crate::core::lookups::utils::{eq, Fraction}; - use crate::core::test_utils::test_channel; - - #[test] - fn gen_eq_evals() { - let zero = SecureField::zero(); - let one = SecureField::one(); - let two = BaseField::from(2).into(); - let y = [7, 3].map(|v| BaseField::from(v).into()); - - let eq_evals = CpuBackend::gen_eq_evals(&y, two); - - assert_eq!( - *eq_evals, - [ - eq(&[zero, zero], &y) * two, - eq(&[zero, one], &y) * two, - eq(&[one, zero], &y) * two, - eq(&[one, one], &y) * two, - ] - ); - } - - #[test] - fn grand_product_works() -> Result<(), GkrError> { - const N: usize = 1 << 5; - let values = test_channel().draw_felts(N); - let product = values.iter().product::(); - let col = Mle::::new(values); - let input_layer = Layer::GrandProduct(col.clone()); - let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); - - let GkrArtifact { - ood_point: r, - claims_to_verify_by_instance, - n_variables_by_instance: _, - } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?; - - assert_eq!(proof.output_claims_by_instance, [vec![product]]); - assert_eq!(claims_to_verify_by_instance, [vec![col.eval_at_point(&r)]]); - Ok(()) - } - - #[test] - fn logup_with_generic_trace_works() -> Result<(), GkrError> { - const N: usize = 1 << 5; - let mut rng = SmallRng::seed_from_u64(0); - let numerator_values = (0..N).map(|_| rng.gen()).collect::>(); - let denominator_values = (0..N).map(|_| rng.gen()).collect::>(); - let sum = zip(&numerator_values, &denominator_values) - .map(|(&n, &d)| Fraction::new(n, d)) - .sum::>(); - let numerators = Mle::::new(numerator_values); - let denominators = Mle::::new(denominator_values); - let top_layer = Layer::LogUpGeneric { - numerators: numerators.clone(), - denominators: denominators.clone(), - }; - let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer]); - - let GkrArtifact { - ood_point, - claims_to_verify_by_instance, - n_variables_by_instance: _, - } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; - - assert_eq!(claims_to_verify_by_instance.len(), 1); - assert_eq!(proof.output_claims_by_instance.len(), 1); - assert_eq!( - claims_to_verify_by_instance[0], - [ - numerators.eval_at_point(&ood_point), - denominators.eval_at_point(&ood_point) - ] - ); - assert_eq!( - proof.output_claims_by_instance[0], - [sum.numerator, sum.denominator] - ); - Ok(()) - } - - #[test] - fn logup_with_singles_trace_works() -> Result<(), GkrError> { - const N: usize = 1 << 5; - let mut rng = SmallRng::seed_from_u64(0); - let denominator_values = (0..N).map(|_| rng.gen()).collect::>(); - let sum = denominator_values - .iter() - .map(|&d| Fraction::new(SecureField::one(), d)) - .sum::>(); - let denominators = Mle::::new(denominator_values); - let top_layer = Layer::LogUpSingles { - denominators: denominators.clone(), - }; - let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer]); - - let GkrArtifact { - ood_point, - claims_to_verify_by_instance, - n_variables_by_instance: _, - } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; - - assert_eq!(claims_to_verify_by_instance.len(), 1); - assert_eq!(proof.output_claims_by_instance.len(), 1); - assert_eq!( - claims_to_verify_by_instance[0], - [SecureField::one(), denominators.eval_at_point(&ood_point)] - ); - assert_eq!( - proof.output_claims_by_instance[0], - [sum.numerator, sum.denominator] - ); - Ok(()) - } - - #[test] - fn logup_with_multiplicities_trace_works() -> Result<(), GkrError> { - const N: usize = 1 << 5; - let mut rng = SmallRng::seed_from_u64(0); - let numerator_values = (0..N).map(|_| rng.gen()).collect::>(); - let denominator_values = (0..N).map(|_| rng.gen()).collect::>(); - let sum = zip(&numerator_values, &denominator_values) - .map(|(&n, &d)| Fraction::new(n.into(), d)) - .sum::>(); - let numerators = Mle::::new(numerator_values); - let denominators = Mle::::new(denominator_values); - let top_layer = Layer::LogUpMultiplicities { - numerators: numerators.clone(), - denominators: denominators.clone(), - }; - let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer]); - - let GkrArtifact { - ood_point, - claims_to_verify_by_instance, - n_variables_by_instance: _, - } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; - - assert_eq!(claims_to_verify_by_instance.len(), 1); - assert_eq!(proof.output_claims_by_instance.len(), 1); - assert_eq!( - claims_to_verify_by_instance[0], - [ - numerators.eval_at_point(&ood_point), - denominators.eval_at_point(&ood_point) - ] - ); - assert_eq!( - proof.output_claims_by_instance[0], - [sum.numerator, sum.denominator] - ); - Ok(()) - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/mle.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/mle.rs deleted file mode 100644 index 35d6632..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/mle.rs +++ /dev/null @@ -1,66 +0,0 @@ -use std::iter::zip; - -use num_traits::{One, Zero}; - -use crate::core::backend::CpuBackend; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::lookups::mle::{Mle, MleOps}; -use crate::core::lookups::sumcheck::MultivariatePolyOracle; -use crate::core::lookups::utils::{fold_mle_evals, UnivariatePoly}; - -impl MleOps for CpuBackend { - fn fix_first_variable( - mle: Mle, - assignment: SecureField, - ) -> Mle { - let midpoint = mle.len() / 2; - let (lhs_evals, rhs_evals) = mle.split_at(midpoint); - - let res = zip(lhs_evals, rhs_evals) - .map(|(&lhs_eval, &rhs_eval)| fold_mle_evals(assignment, lhs_eval, rhs_eval)) - .collect(); - - Mle::new(res) - } -} - -impl MleOps for CpuBackend { - fn fix_first_variable( - mle: Mle, - assignment: SecureField, - ) -> Mle { - let midpoint = mle.len() / 2; - let mut evals = mle.into_evals(); - - for i in 0..midpoint { - let lhs_eval = evals[i]; - let rhs_eval = evals[i + midpoint]; - evals[i] = fold_mle_evals(assignment, lhs_eval, rhs_eval); - } - - evals.truncate(midpoint); - - Mle::new(evals) - } -} - -impl MultivariatePolyOracle for Mle { - fn n_variables(&self) -> usize { - self.n_variables() - } - - fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly { - let x0 = SecureField::zero(); - let x1 = SecureField::one(); - - let y0 = self[0..self.len() / 2].iter().sum(); - let y1 = claim - y0; - - UnivariatePoly::interpolate_lagrange(&[x0, x1], &[y0, y1]) - } - - fn fix_first_variable(self, challenge: SecureField) -> Self { - self.fix_first_variable(challenge) - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/mod.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/mod.rs deleted file mode 100644 index cd8dedf..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod gkr; -mod mle; diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/mod.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/mod.rs deleted file mode 100644 index 579b735..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/cpu/mod.rs +++ /dev/null @@ -1,105 +0,0 @@ -mod accumulation; -mod blake2s; -mod circle; -mod fri; -mod grind; -pub mod lookups; -#[cfg(not(target_arch = "wasm32"))] -mod poseidon252; -pub mod quotients; -#[cfg(not(target_arch = "wasm32"))] -mod poseidon_bls; - -use std::fmt::Debug; - -use serde::{Deserialize, Serialize}; - -use super::{Backend, BackendForChannel, Column, ColumnOps, FieldOps}; -use crate::core::fields::Field; -use crate::core::lookups::mle::Mle; -use crate::core::poly::circle::{CircleEvaluation, CirclePoly}; -use crate::core::utils::bit_reverse; -use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; -#[cfg(not(target_arch = "wasm32"))] -use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel; - -#[cfg(not(target_arch = "wasm32"))] -use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleChannel; - -#[derive(Copy, Clone, Debug, Deserialize, Serialize)] -pub struct CpuBackend; - -impl Backend for CpuBackend {} -impl BackendForChannel for CpuBackend {} -#[cfg(not(target_arch = "wasm32"))] -impl BackendForChannel for CpuBackend {} - -#[cfg(not(target_arch = "wasm32"))] -impl BackendForChannel for CpuBackend {} - -impl ColumnOps for CpuBackend { - type Column = Vec; - - fn bit_reverse_column(column: &mut Self::Column) { - bit_reverse(column) - } -} - -impl FieldOps for CpuBackend { - /// Batch inversion using the Montgomery's trick. - // TODO(Ohad): Benchmark this function. - fn batch_inverse(column: &Self::Column, dst: &mut Self::Column) { - F::batch_inverse(column, &mut dst[..]); - } -} - -impl Column for Vec { - fn zeros(len: usize) -> Self { - vec![T::default(); len] - } - #[allow(clippy::uninit_vec)] - unsafe fn uninitialized(length: usize) -> Self { - let mut data = Vec::with_capacity(length); - data.set_len(length); - data - } - fn to_cpu(&self) -> Vec { - self.clone() - } - fn len(&self) -> usize { - self.len() - } - fn at(&self, index: usize) -> T { - self[index].clone() - } - fn set(&mut self, index: usize, value: T) { - self[index] = value; - } -} - -pub type CpuCirclePoly = CirclePoly; -pub type CpuCircleEvaluation = CircleEvaluation; -pub type CpuMle = Mle; - -#[cfg(test)] -mod tests { - use itertools::Itertools; - use rand::prelude::*; - use rand::rngs::SmallRng; - - use crate::core::backend::{Column, CpuBackend, FieldOps}; - use crate::core::fields::qm31::QM31; - use crate::core::fields::FieldExpOps; - - #[test] - fn batch_inverse_test() { - let mut rng = SmallRng::seed_from_u64(0); - let column = rng.gen::<[QM31; 16]>().to_vec(); - let expected = column.iter().map(|e| e.inverse()).collect_vec(); - let mut dst = Column::zeros(column.len()); - - CpuBackend::batch_inverse(&column, &mut dst); - - assert_eq!(expected, dst); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/poseidon252.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/poseidon252.rs deleted file mode 100644 index 8cc5dd9..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/cpu/poseidon252.rs +++ /dev/null @@ -1,24 +0,0 @@ -use itertools::Itertools; -use starknet_ff::FieldElement as FieldElement252; - -use super::CpuBackend; -use crate::core::fields::m31::BaseField; -use crate::core::vcs::ops::{MerkleHasher, MerkleOps}; -use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleHasher; - -impl MerkleOps for CpuBackend { - fn commit_on_layer( - log_size: u32, - prev_layer: Option<&Vec>, - columns: &[&Vec], - ) -> Vec { - (0..(1 << log_size)) - .map(|i| { - Poseidon252MerkleHasher::hash_node( - prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])), - &columns.iter().map(|column| column[i]).collect_vec(), - ) - }) - .collect() - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/poseidon_bls.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/poseidon_bls.rs deleted file mode 100644 index a6cca33..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/cpu/poseidon_bls.rs +++ /dev/null @@ -1,24 +0,0 @@ -use itertools::Itertools; -use ark_bls12_381::Fr as BlsFr; - -use super::CpuBackend; -use crate::core::fields::m31::BaseField; -use crate::core::vcs::ops::{MerkleHasher, MerkleOps}; -use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleHasher; - -impl MerkleOps for CpuBackend { - fn commit_on_layer( - log_size: u32, - prev_layer: Option<&Vec>, - columns: &[&Vec], - ) -> Vec { - (0..(1 << log_size)) - .map(|i| { - PoseidonBLSMerkleHasher::hash_node( - prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])), - &columns.iter().map(|column| column[i]).collect_vec(), - ) - }) - .collect() - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/quotients.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/quotients.rs deleted file mode 100644 index 17cc007..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/cpu/quotients.rs +++ /dev/null @@ -1,210 +0,0 @@ -use itertools::{izip, zip_eq}; -use num_traits::{One, Zero}; - -use super::CpuBackend; -use crate::core::circle::CirclePoint; -use crate::core::constraints::complex_conjugate_line_coeffs; -use crate::core::fields::cm31::CM31; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SecureColumnByCoords; -use crate::core::fields::FieldExpOps; -use crate::core::pcs::quotients::{ColumnSampleBatch, PointSample, QuotientOps}; -use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation}; -use crate::core::poly::BitReversedOrder; -use crate::core::utils::{bit_reverse, bit_reverse_index}; - -impl QuotientOps for CpuBackend { - fn accumulate_quotients( - domain: CircleDomain, - columns: &[&CircleEvaluation], - random_coeff: SecureField, - sample_batches: &[ColumnSampleBatch], - _log_blowup_factor: u32, - ) -> SecureEvaluation { - let mut values = unsafe { SecureColumnByCoords::uninitialized(domain.size()) }; - let quotient_constants = quotient_constants(sample_batches, random_coeff, domain); - - for row in 0..domain.size() { - let domain_point = domain.at(bit_reverse_index(row, domain.log_size())); - let row_value = accumulate_row_quotients( - sample_batches, - columns, - "ient_constants, - row, - domain_point, - ); - values.set(row, row_value); - } - SecureEvaluation::new(domain, values) - } -} - -pub fn accumulate_row_quotients( - sample_batches: &[ColumnSampleBatch], - columns: &[&CircleEvaluation], - quotient_constants: &QuotientConstants, - row: usize, - domain_point: CirclePoint, -) -> SecureField { - let mut row_accumulator = SecureField::zero(); - for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!( - sample_batches, - "ient_constants.line_coeffs, - "ient_constants.batch_random_coeffs, - "ient_constants.denominator_inverses - ) { - let mut numerator = SecureField::zero(); - for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs) - { - let column = &columns[*column_index]; - let value = column[row] * *c; - // The numerator is a line equation passing through - // (sample_point.y, sample_value), (conj(sample_point), conj(sample_value)) - // evaluated at (domain_point.y, value). - // When substituting a polynomial in this line equation, we get a polynomial with a root - // at sample_point and conj(sample_point) if the original polynomial had the values - // sample_value and conj(sample_value) at these points. - let linear_term = *a * domain_point.y + *b; - numerator += value - linear_term; - } - - row_accumulator = - row_accumulator * *batch_coeff + numerator.mul_cm31(denominator_inverses[row]); - } - row_accumulator -} - -/// Precompute the complex conjugate line coefficients for each column in each sample batch. -/// Specifically, for the i-th (in a sample batch) column's numerator term -/// `alpha^i * (c * F(p) - (a * p.y + b))`, we precompute and return the constants: -/// (`alpha^i * a`, `alpha^i * b`, `alpha^i * c`). -pub fn column_line_coeffs( - sample_batches: &[ColumnSampleBatch], - random_coeff: SecureField, -) -> Vec> { - sample_batches - .iter() - .map(|sample_batch| { - let mut alpha = SecureField::one(); - sample_batch - .columns_and_values - .iter() - .map(|(_, sampled_value)| { - alpha *= random_coeff; - let sample = PointSample { - point: sample_batch.point, - value: *sampled_value, - }; - complex_conjugate_line_coeffs(&sample, alpha) - }) - .collect() - }) - .collect() -} - -/// Precompute the random coefficients used to linearly combine the batched quotients. -/// Specifically, for each sample batch we compute random_coeff^(number of columns in the batch), -/// which is used to linearly combine the batch with the next one. -pub fn batch_random_coeffs( - sample_batches: &[ColumnSampleBatch], - random_coeff: SecureField, -) -> Vec { - sample_batches - .iter() - .map(|sb| random_coeff.pow(sb.columns_and_values.len() as u128)) - .collect() -} - -fn denominator_inverses( - sample_batches: &[ColumnSampleBatch], - domain: CircleDomain, -) -> Vec> { - let mut flat_denominators = Vec::with_capacity(sample_batches.len() * domain.size()); - // We want a P to be on a line that passes through a point Pr + uPi in QM31^2, and its conjugate - // Pr - uPi. Thus, Pr - P is parallel to Pi. Or, (Pr - P).x * Pi.y - (Pr - P).y * Pi.x = 0. - for sample_batch in sample_batches { - // Extract Pr, Pi. - let prx = sample_batch.point.x.0; - let pry = sample_batch.point.y.0; - let pix = sample_batch.point.x.1; - let piy = sample_batch.point.y.1; - for row in 0..domain.size() { - let domain_point = domain.at(row); - flat_denominators.push((prx - domain_point.x) * piy - (pry - domain_point.y) * pix); - } - } - - let mut flat_denominator_inverses = vec![CM31::zero(); flat_denominators.len()]; - CM31::batch_inverse(&flat_denominators, &mut flat_denominator_inverses); - - flat_denominator_inverses - .chunks_mut(domain.size()) - .map(|denominator_inverses| { - bit_reverse(denominator_inverses); - denominator_inverses.to_vec() - }) - .collect() -} - -pub fn quotient_constants( - sample_batches: &[ColumnSampleBatch], - random_coeff: SecureField, - domain: CircleDomain, -) -> QuotientConstants { - let line_coeffs = column_line_coeffs(sample_batches, random_coeff); - let batch_random_coeffs = batch_random_coeffs(sample_batches, random_coeff); - let denominator_inverses = denominator_inverses(sample_batches, domain); - QuotientConstants { - line_coeffs, - batch_random_coeffs, - denominator_inverses, - } -} - -/// Holds the precomputed constant values used in each quotient evaluation. -pub struct QuotientConstants { - /// The line coefficients for each quotient numerator term. For more details see - /// [self::column_line_coeffs]. - pub line_coeffs: Vec>, - /// The random coefficients used to linearly combine the batched quotients For more details see - /// [self::batch_random_coeffs]. - pub batch_random_coeffs: Vec, - /// The inverses of the denominators of the quotients. - pub denominator_inverses: Vec>, -} - -#[cfg(test)] -mod tests { - use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly}; - use crate::core::backend::CpuBackend; - use crate::core::circle::SECURE_FIELD_CIRCLE_GEN; - use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps}; - use crate::core::poly::circle::CanonicCoset; - use crate::{m31, qm31}; - - #[test] - fn test_quotients_are_low_degree() { - const LOG_SIZE: u32 = 7; - const LOG_BLOWUP_FACTOR: u32 = 1; - let polynomial = CpuCirclePoly::new((0..1 << LOG_SIZE).map(|i| m31!(i)).collect()); - let eval_domain = CanonicCoset::new(LOG_SIZE + 1).circle_domain(); - let eval = polynomial.evaluate(eval_domain); - let point = SECURE_FIELD_CIRCLE_GEN; - let value = polynomial.eval_at_point(point); - let coeff = qm31!(1, 2, 3, 4); - let quot_eval = CpuBackend::accumulate_quotients( - eval_domain, - &[&eval], - coeff, - &[ColumnSampleBatch { - point, - columns_and_values: vec![(0, value)], - }], - LOG_BLOWUP_FACTOR, - ); - let quot_poly_base_field = - CpuCircleEvaluation::new(eval_domain, quot_eval.columns[0].clone()).interpolate(); - assert!(quot_poly_base_field.is_in_fri_space(LOG_SIZE)); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/mod.rs b/Stwo_wrapper/crates/prover/src/core/backend/mod.rs deleted file mode 100644 index f6eae91..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/mod.rs +++ /dev/null @@ -1,66 +0,0 @@ -use std::fmt::Debug; - -pub use cpu::CpuBackend; - -use super::air::accumulation::AccumulationOps; -use super::channel::MerkleChannel; -use super::fields::m31::BaseField; -use super::fields::qm31::SecureField; -use super::fields::FieldOps; -use super::fri::FriOps; -use super::lookups::gkr_prover::GkrOps; -use super::pcs::quotients::QuotientOps; -use super::poly::circle::PolyOps; -use super::proof_of_work::GrindOps; -use super::vcs::ops::MerkleOps; - -pub mod cpu; -pub mod simd; - -pub trait Backend: - Copy - + Clone - + Debug - + FieldOps - + FieldOps - + PolyOps - + QuotientOps - + FriOps - + AccumulationOps - + GkrOps -{ -} - -pub trait BackendForChannel: - Backend + MerkleOps + GrindOps -{ -} - -pub trait ColumnOps { - type Column: Column; - fn bit_reverse_column(column: &mut Self::Column); -} - -pub type Col = >::Column; - -// TODO(spapini): Consider removing the generic parameter and only support BaseField. -pub trait Column: Clone + Debug + FromIterator { - /// Creates a new column of zeros with the given length. - fn zeros(len: usize) -> Self; - /// Creates a new column of uninitialized values with the given length. - /// # Safety - /// The caller must ensure that the column is populated before being used. - unsafe fn uninitialized(len: usize) -> Self; - /// Returns a cpu vector of the column. - fn to_cpu(&self) -> Vec; - /// Returns the length of the column. - fn len(&self) -> usize; - /// Returns true if the column is empty. - fn is_empty(&self) -> bool { - self.len() == 0 - } - /// Retrieves the element at the given index. - fn at(&self, index: usize) -> T; - /// Sets the element at the given index. - fn set(&mut self, index: usize, value: T); -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/accumulation.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/accumulation.rs deleted file mode 100644 index c9705df..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/accumulation.rs +++ /dev/null @@ -1,12 +0,0 @@ -use super::SimdBackend; -use crate::core::air::accumulation::AccumulationOps; -use crate::core::fields::secure_column::SecureColumnByCoords; - -impl AccumulationOps for SimdBackend { - fn accumulate(column: &mut SecureColumnByCoords, other: &SecureColumnByCoords) { - for i in 0..column.packed_len() { - let res_coeff = unsafe { column.packed_at(i) + other.packed_at(i) }; - unsafe { column.set_packed(i, res_coeff) }; - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/bit_reverse.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/bit_reverse.rs deleted file mode 100644 index 13d6585..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/bit_reverse.rs +++ /dev/null @@ -1,203 +0,0 @@ -use std::array; - -use super::column::{BaseColumn, SecureColumn}; -use super::m31::PackedBaseField; -use super::SimdBackend; -use crate::core::backend::ColumnOps; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::utils::{bit_reverse as cpu_bit_reverse, bit_reverse_index}; - -const VEC_BITS: u32 = 4; - -const W_BITS: u32 = 3; - -pub const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS; - -impl ColumnOps for SimdBackend { - type Column = BaseColumn; - - fn bit_reverse_column(column: &mut Self::Column) { - // Fallback to cpu bit_reverse. - if column.data.len().ilog2() < MIN_LOG_SIZE { - cpu_bit_reverse(column.as_mut_slice()); - return; - } - - bit_reverse_m31(&mut column.data); - } -} - -impl ColumnOps for SimdBackend { - type Column = SecureColumn; - - fn bit_reverse_column(_column: &mut SecureColumn) { - todo!() - } -} - -/// Bit reverses M31 values. -/// -/// Given an array `A[0..2^n)`, computes `B[i] = A[bit_reverse(i)]`. -pub fn bit_reverse_m31(data: &mut [PackedBaseField]) { - assert!(data.len().is_power_of_two()); - assert!(data.len().ilog2() >= MIN_LOG_SIZE); - - // Indices in the array are of the form v_h w_h a w_l v_l, with - // |v_h| = |v_l| = VEC_BITS, |w_h| = |w_l| = W_BITS, |a| = n - 2*W_BITS - VEC_BITS. - // The loops go over a, w_l, w_h, and then swaps the 16 by 16 values at: - // * w_h a w_l * <-> * rev(w_h a w_l) *. - // These are 1 or 2 chunks of 2^W_BITS contiguous `u32x16` vectors. - - let log_size = data.len().ilog2(); - let a_bits = log_size - 2 * W_BITS - VEC_BITS; - - // TODO(spapini): when doing multithreading, do it over a. - for a in 0u32..1 << a_bits { - for w_l in 0u32..1 << W_BITS { - let w_l_rev = w_l.reverse_bits() >> (u32::BITS - W_BITS); - for w_h in 0..w_l_rev + 1 { - let idx = ((((w_h << a_bits) | a) << W_BITS) | w_l) as usize; - let idx_rev = bit_reverse_index(idx, log_size - VEC_BITS); - - // In order to not swap twice, only swap if idx <= idx_rev. - if idx > idx_rev { - continue; - } - - // Read first chunk. - // TODO(spapini): Think about optimizing a_bits. - let chunk0 = array::from_fn(|i| unsafe { - *data.get_unchecked(idx + (i << (2 * W_BITS + a_bits))) - }); - let values0 = bit_reverse16(chunk0); - - if idx == idx_rev { - // Palindrome index. Write into the same chunk. - #[allow(clippy::needless_range_loop)] - for i in 0..16 { - unsafe { - *data.get_unchecked_mut(idx + (i << (2 * W_BITS + a_bits))) = - values0[i]; - } - } - continue; - } - - // Read bit reversed chunk. - let chunk1 = array::from_fn(|i| unsafe { - *data.get_unchecked(idx_rev + (i << (2 * W_BITS + a_bits))) - }); - let values1 = bit_reverse16(chunk1); - - for i in 0..16 { - unsafe { - *data.get_unchecked_mut(idx + (i << (2 * W_BITS + a_bits))) = values1[i]; - *data.get_unchecked_mut(idx_rev + (i << (2 * W_BITS + a_bits))) = - values0[i]; - } - } - } - } - } -} - -/// Bit reverses 256 M31 values, packed in 16 words of 16 elements each. -fn bit_reverse16(mut data: [PackedBaseField; 16]) -> [PackedBaseField; 16] { - // Denote the index of each element in the 16 packed M31 words as abcd:0123, - // where abcd is the index of the packed word and 0123 is the index of the element in the word. - // Bit reversal is achieved by applying the following permutation to the index for 4 times: - // abcd:0123 => 0abc:123d - // This is how it looks like at each iteration. - // abcd:0123 - // 0abc:123d - // 10ab:23dc - // 210a:3dcb - // 3210:dcba - for _ in 0..4 { - // Apply the abcd:0123 => 0abc:123d permutation. - // `interleave` allows us to interleave the first half of 2 words. - // For example, the second call interleaves 0010:0xyz (low half of register 2) with - // 0011:0xyz (low half of register 3), and stores the result in register 1 (0001). - // This results in - // 0001:xyz0 (even indices of register 1) <= 0010:0xyz (low half of register2), and - // 0001:xyz1 (odd indices of register 1) <= 0011:0xyz (low half of register 3) - // or 0001:xyzw <= 001w:0xyz. - let (d0, d8) = data[0].interleave(data[1]); - let (d1, d9) = data[2].interleave(data[3]); - let (d2, d10) = data[4].interleave(data[5]); - let (d3, d11) = data[6].interleave(data[7]); - let (d4, d12) = data[8].interleave(data[9]); - let (d5, d13) = data[10].interleave(data[11]); - let (d6, d14) = data[12].interleave(data[13]); - let (d7, d15) = data[14].interleave(data[15]); - data = [ - d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15, - ]; - } - - data -} - -#[cfg(test)] -mod tests { - use itertools::Itertools; - - use super::{bit_reverse16, bit_reverse_m31, MIN_LOG_SIZE}; - use crate::core::backend::simd::column::BaseColumn; - use crate::core::backend::simd::m31::{PackedM31, N_LANES}; - use crate::core::backend::simd::SimdBackend; - use crate::core::backend::{Column, ColumnOps}; - use crate::core::fields::m31::BaseField; - use crate::core::utils::bit_reverse as cpu_bit_reverse; - - #[test] - fn test_bit_reverse16() { - let values: BaseColumn = (0..N_LANES * 16).map(BaseField::from).collect(); - let mut expected = values.to_cpu(); - cpu_bit_reverse(&mut expected); - - let res = bit_reverse16(values.data.try_into().unwrap()); - - assert_eq!(res.map(PackedM31::to_array).flatten(), expected); - } - - #[test] - fn bit_reverse_m31_works() { - const SIZE: usize = 1 << 15; - let data: Vec<_> = (0..SIZE).map(BaseField::from).collect(); - let mut expected = data.clone(); - cpu_bit_reverse(&mut expected); - - let mut res: BaseColumn = data.into_iter().collect(); - bit_reverse_m31(&mut res.data[..]); - - assert_eq!(res.to_cpu(), expected); - } - - #[test] - fn bit_reverse_small_column_works() { - const LOG_SIZE: u32 = MIN_LOG_SIZE - 1; - let column = (0..1 << LOG_SIZE).map(BaseField::from).collect_vec(); - let mut expected = column.clone(); - cpu_bit_reverse(&mut expected); - - let mut res = column.iter().copied().collect::(); - >::bit_reverse_column(&mut res); - - assert_eq!(res.to_cpu(), expected); - } - - #[test] - fn bit_reverse_large_column_works() { - const LOG_SIZE: u32 = MIN_LOG_SIZE; - let column = (0..1 << LOG_SIZE).map(BaseField::from).collect_vec(); - let mut expected = column.clone(); - cpu_bit_reverse(&mut expected); - - let mut res = column.iter().copied().collect::(); - >::bit_reverse_column(&mut res); - - assert_eq!(res.to_cpu(), expected); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/blake2s.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/blake2s.rs deleted file mode 100644 index fbcfe89..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/blake2s.rs +++ /dev/null @@ -1,412 +0,0 @@ -//! A SIMD implementation of the BLAKE2s compression function. -//! Based on . - -use std::array; -use std::iter::repeat; -use std::mem::transmute; -use std::simd::u32x16; - -use bytemuck::cast_slice; -use itertools::Itertools; -#[cfg(feature = "parallel")] -use rayon::prelude::*; - -use super::m31::{LOG_N_LANES, N_LANES}; -use super::SimdBackend; -use crate::core::backend::{Col, Column, ColumnOps}; -use crate::core::fields::m31::BaseField; -use crate::core::vcs::blake2_hash::Blake2sHash; -use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; -use crate::core::vcs::ops::{MerkleHasher, MerkleOps}; - -const IV: [u32; 8] = [ - 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, -]; - -pub const SIGMA: [[u8; 16]; 10] = [ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], - [14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3], - [11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4], - [7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8], - [9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13], - [2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9], - [12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11], - [13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10], - [6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5], - [10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0], -]; - -impl ColumnOps for SimdBackend { - type Column = Vec; - - fn bit_reverse_column(_column: &mut Self::Column) { - unimplemented!() - } -} - -impl MerkleOps for SimdBackend { - fn commit_on_layer( - log_size: u32, - prev_layer: Option<&Vec>, - columns: &[&Col], - ) -> Vec { - if log_size < LOG_N_LANES { - #[cfg(not(feature = "parallel"))] - let iter = 0..1 << log_size; - - #[cfg(feature = "parallel")] - let iter = (0..1 << log_size).into_par_iter(); - - return iter - .map(|i| { - Blake2sMerkleHasher::hash_node( - prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])), - &columns.iter().map(|column| column.at(i)).collect_vec(), - ) - }) - .collect(); - } - - if let Some(prev_layer) = prev_layer { - assert_eq!(prev_layer.len(), 1 << (log_size + 1)); - } - - let zeros = u32x16::splat(0); - - // Commit to columns. - let mut res = vec![Blake2sHash::default(); 1 << log_size]; - #[cfg(not(feature = "parallel"))] - let iter = res.chunks_mut(1 << LOG_N_LANES); - - #[cfg(feature = "parallel")] - let iter = res.par_chunks_mut(1 << LOG_N_LANES); - - iter.enumerate().for_each(|(i, chunk)| { - let mut state: [u32x16; 8] = unsafe { std::mem::zeroed() }; - // Hash prev_layer, if exists. - if let Some(prev_layer) = prev_layer { - let prev_chunk_u32s = cast_slice::<_, u32>(&prev_layer[(i << 5)..((i + 1) << 5)]); - // Note: prev_layer might be unaligned. - let msgs: [u32x16; 16] = array::from_fn(|j| { - u32x16::from_array(std::array::from_fn(|k| prev_chunk_u32s[16 * j + k])) - }); - state = compress16(state, transpose_msgs(msgs), zeros, zeros, zeros, zeros); - } - - // Hash columns in chunks of 16. - let mut col_chunk_iter = columns.array_chunks(); - for col_chunk in &mut col_chunk_iter { - let msgs = col_chunk.map(|column| column.data[i].into_simd()); - state = compress16(state, msgs, zeros, zeros, zeros, zeros); - } - - // Hash remaining columns. - let remainder = col_chunk_iter.remainder(); - if !remainder.is_empty() { - let msgs = remainder - .iter() - .map(|column| column.data[i].into_simd()) - .chain(repeat(zeros)) - .take(N_LANES) - .collect_vec() - .try_into() - .unwrap(); - state = compress16(state, msgs, zeros, zeros, zeros, zeros); - } - let state: [Blake2sHash; 16] = unsafe { transmute(untranspose_states(state)) }; - chunk.copy_from_slice(&state); - }); - res - } -} - -/// Applies [`u32::rotate_right(N)`] to each element of the vector -/// -/// [`u32::rotate_right(N)`]: u32::rotate_right -#[inline(always)] -fn rotate(x: u32x16) -> u32x16 { - (x >> N) | (x << (u32::BITS - N)) -} - -// `inline(always)` can cause code parsing errors for wasm: "locals exceed maximum". -#[cfg_attr(not(target_arch = "wasm32"), inline(always))] -pub fn round(v: &mut [u32x16; 16], m: [u32x16; 16], r: usize) { - v[0] += m[SIGMA[r][0] as usize]; - v[1] += m[SIGMA[r][2] as usize]; - v[2] += m[SIGMA[r][4] as usize]; - v[3] += m[SIGMA[r][6] as usize]; - v[0] += v[4]; - v[1] += v[5]; - v[2] += v[6]; - v[3] += v[7]; - v[12] ^= v[0]; - v[13] ^= v[1]; - v[14] ^= v[2]; - v[15] ^= v[3]; - v[12] = rotate::<16>(v[12]); - v[13] = rotate::<16>(v[13]); - v[14] = rotate::<16>(v[14]); - v[15] = rotate::<16>(v[15]); - v[8] += v[12]; - v[9] += v[13]; - v[10] += v[14]; - v[11] += v[15]; - v[4] ^= v[8]; - v[5] ^= v[9]; - v[6] ^= v[10]; - v[7] ^= v[11]; - v[4] = rotate::<12>(v[4]); - v[5] = rotate::<12>(v[5]); - v[6] = rotate::<12>(v[6]); - v[7] = rotate::<12>(v[7]); - v[0] += m[SIGMA[r][1] as usize]; - v[1] += m[SIGMA[r][3] as usize]; - v[2] += m[SIGMA[r][5] as usize]; - v[3] += m[SIGMA[r][7] as usize]; - v[0] += v[4]; - v[1] += v[5]; - v[2] += v[6]; - v[3] += v[7]; - v[12] ^= v[0]; - v[13] ^= v[1]; - v[14] ^= v[2]; - v[15] ^= v[3]; - v[12] = rotate::<8>(v[12]); - v[13] = rotate::<8>(v[13]); - v[14] = rotate::<8>(v[14]); - v[15] = rotate::<8>(v[15]); - v[8] += v[12]; - v[9] += v[13]; - v[10] += v[14]; - v[11] += v[15]; - v[4] ^= v[8]; - v[5] ^= v[9]; - v[6] ^= v[10]; - v[7] ^= v[11]; - v[4] = rotate::<7>(v[4]); - v[5] = rotate::<7>(v[5]); - v[6] = rotate::<7>(v[6]); - v[7] = rotate::<7>(v[7]); - - v[0] += m[SIGMA[r][8] as usize]; - v[1] += m[SIGMA[r][10] as usize]; - v[2] += m[SIGMA[r][12] as usize]; - v[3] += m[SIGMA[r][14] as usize]; - v[0] += v[5]; - v[1] += v[6]; - v[2] += v[7]; - v[3] += v[4]; - v[15] ^= v[0]; - v[12] ^= v[1]; - v[13] ^= v[2]; - v[14] ^= v[3]; - v[15] = rotate::<16>(v[15]); - v[12] = rotate::<16>(v[12]); - v[13] = rotate::<16>(v[13]); - v[14] = rotate::<16>(v[14]); - v[10] += v[15]; - v[11] += v[12]; - v[8] += v[13]; - v[9] += v[14]; - v[5] ^= v[10]; - v[6] ^= v[11]; - v[7] ^= v[8]; - v[4] ^= v[9]; - v[5] = rotate::<12>(v[5]); - v[6] = rotate::<12>(v[6]); - v[7] = rotate::<12>(v[7]); - v[4] = rotate::<12>(v[4]); - v[0] += m[SIGMA[r][9] as usize]; - v[1] += m[SIGMA[r][11] as usize]; - v[2] += m[SIGMA[r][13] as usize]; - v[3] += m[SIGMA[r][15] as usize]; - v[0] += v[5]; - v[1] += v[6]; - v[2] += v[7]; - v[3] += v[4]; - v[15] ^= v[0]; - v[12] ^= v[1]; - v[13] ^= v[2]; - v[14] ^= v[3]; - v[15] = rotate::<8>(v[15]); - v[12] = rotate::<8>(v[12]); - v[13] = rotate::<8>(v[13]); - v[14] = rotate::<8>(v[14]); - v[10] += v[15]; - v[11] += v[12]; - v[8] += v[13]; - v[9] += v[14]; - v[5] ^= v[10]; - v[6] ^= v[11]; - v[7] ^= v[8]; - v[4] ^= v[9]; - v[5] = rotate::<7>(v[5]); - v[6] = rotate::<7>(v[6]); - v[7] = rotate::<7>(v[7]); - v[4] = rotate::<7>(v[4]); -} - -/// Transposes input chunks (16 chunks of 16 `u32`s each), to get 16 `u32x16`, each -/// representing 16 packed instances of a message word. -fn transpose_msgs(mut data: [u32x16; 16]) -> [u32x16; 16] { - // Index abcd:xyzw, refers to a specific word in data as follows: - // abcd - chunk index (in base 2) - // xyzw - word offset (in base 2) - // Transpose by applying 4 times the index permutation: - // abcd:xyzw => wabc:dxyz - // In other words, rotate the index to the right by 1. - for _ in 0..4 { - let (d0, d8) = data[0].deinterleave(data[1]); - let (d1, d9) = data[2].deinterleave(data[3]); - let (d2, d10) = data[4].deinterleave(data[5]); - let (d3, d11) = data[6].deinterleave(data[7]); - let (d4, d12) = data[8].deinterleave(data[9]); - let (d5, d13) = data[10].deinterleave(data[11]); - let (d6, d14) = data[12].deinterleave(data[13]); - let (d7, d15) = data[14].deinterleave(data[15]); - data = [ - d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15, - ]; - } - - data -} - -fn untranspose_states(mut states: [u32x16; 8]) -> [u32x16; 8] { - // Index abc:xyzw, refers to a specific word in data as follows: - // abc - chunk index (in base 2) - // xyzw - word offset (in base 2) - // Transpose by applying 3 times the index permutation: - // abc:xyzw => bcx:yzwa - // In other words, rotate the index to the left by 1. - for _ in 0..3 { - let (d0, d1) = states[0].interleave(states[4]); - let (d2, d3) = states[1].interleave(states[5]); - let (d4, d5) = states[2].interleave(states[6]); - let (d6, d7) = states[3].interleave(states[7]); - states = [d0, d1, d2, d3, d4, d5, d6, d7]; - } - states -} - -/// Compresses 16 blake2s instances. -pub fn compress16( - h_vecs: [u32x16; 8], - msg_vecs: [u32x16; 16], - count_low: u32x16, - count_high: u32x16, - lastblock: u32x16, - lastnode: u32x16, -) -> [u32x16; 8] { - let mut v = [ - h_vecs[0], - h_vecs[1], - h_vecs[2], - h_vecs[3], - h_vecs[4], - h_vecs[5], - h_vecs[6], - h_vecs[7], - u32x16::splat(IV[0]), - u32x16::splat(IV[1]), - u32x16::splat(IV[2]), - u32x16::splat(IV[3]), - u32x16::splat(IV[4]) ^ count_low, - u32x16::splat(IV[5]) ^ count_high, - u32x16::splat(IV[6]) ^ lastblock, - u32x16::splat(IV[7]) ^ lastnode, - ]; - - round(&mut v, msg_vecs, 0); - round(&mut v, msg_vecs, 1); - round(&mut v, msg_vecs, 2); - round(&mut v, msg_vecs, 3); - round(&mut v, msg_vecs, 4); - round(&mut v, msg_vecs, 5); - round(&mut v, msg_vecs, 6); - round(&mut v, msg_vecs, 7); - round(&mut v, msg_vecs, 8); - round(&mut v, msg_vecs, 9); - - [ - h_vecs[0] ^ v[0] ^ v[8], - h_vecs[1] ^ v[1] ^ v[9], - h_vecs[2] ^ v[2] ^ v[10], - h_vecs[3] ^ v[3] ^ v[11], - h_vecs[4] ^ v[4] ^ v[12], - h_vecs[5] ^ v[5] ^ v[13], - h_vecs[6] ^ v[6] ^ v[14], - h_vecs[7] ^ v[7] ^ v[15], - ] -} - -#[cfg(test)] -mod tests { - use std::array; - use std::mem::transmute; - use std::simd::u32x16; - - use aligned::{Aligned, A64}; - - use super::{compress16, transpose_msgs, untranspose_states}; - use crate::core::vcs::blake2s_ref::compress; - - #[test] - fn compress16_works() { - let states: Aligned = - Aligned(array::from_fn(|i| array::from_fn(|j| (i + j) as u32))); - let msgs: Aligned = - Aligned(array::from_fn(|i| array::from_fn(|j| (i + j + 20) as u32))); - let count_low = 1; - let count_high = 2; - let lastblock = 3; - let lastnode = 4; - let res_unvectorized = array::from_fn(|i| { - compress( - states[i], msgs[i], count_low, count_high, lastblock, lastnode, - ) - }); - - let res_vectorized: [[u32; 8]; 16] = unsafe { - transmute(untranspose_states(compress16( - transpose_states(transmute(states)), - transpose_msgs(transmute(msgs)), - u32x16::splat(count_low), - u32x16::splat(count_high), - u32x16::splat(lastblock), - u32x16::splat(lastnode), - ))) - }; - - assert_eq!(res_vectorized, res_unvectorized); - } - - #[test] - fn untranspose_states_is_transpose_states_inverse() { - let states = array::from_fn(|i| u32x16::from(array::from_fn(|j| (i + j) as u32))); - let transposed_states = transpose_states(states); - - let untrasponsed_transposed_states = untranspose_states(transposed_states); - - assert_eq!(untrasponsed_transposed_states, states) - } - - /// Transposes states, from 8 packed words, to get 16 results, each of size 32B. - fn transpose_states(mut states: [u32x16; 8]) -> [u32x16; 8] { - // Index abc:xyzw, refers to a specific word in data as follows: - // abc - chunk index (in base 2) - // xyzw - word offset (in base 2) - // Transpose by applying 3 times the index permutation: - // abc:xyzw => wab:cxyz - // In other words, rotate the index to the right by 1. - for _ in 0..3 { - let (s0, s4) = states[0].deinterleave(states[1]); - let (s1, s5) = states[2].deinterleave(states[3]); - let (s2, s6) = states[4].deinterleave(states[5]); - let (s3, s7) = states[6].deinterleave(states[7]); - states = [s0, s1, s2, s3, s4, s5, s6, s7]; - } - - states - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/circle.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/circle.rs deleted file mode 100644 index e930f77..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/circle.rs +++ /dev/null @@ -1,436 +0,0 @@ -use std::iter::zip; -use std::mem::transmute; - -use bytemuck::{cast_slice, Zeroable}; -use num_traits::One; - -use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE}; -use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; -use super::qm31::PackedSecureField; -use super::SimdBackend; -use crate::core::backend::simd::column::BaseColumn; -use crate::core::backend::{Col, CpuBackend}; -use crate::core::circle::{CirclePoint, Coset}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::{Field, FieldExpOps}; -use crate::core::poly::circle::{ - CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps, -}; -use crate::core::poly::twiddles::TwiddleTree; -use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold}; -use crate::core::poly::BitReversedOrder; - -impl SimdBackend { - // TODO(Ohad): optimize. - fn twiddle_at(mappings: &[F], mut index: usize) -> F { - debug_assert!( - (1 << mappings.len()) as usize >= index, - "Index out of bounds. mappings log len = {}, index = {index}", - mappings.len().ilog2() - ); - - let mut product = F::one(); - for &num in mappings.iter() { - if index & 1 == 1 { - product *= num; - } - index >>= 1; - if index == 0 { - break; - } - } - - product - } - - // TODO(Ohad): consider moving this to to a more general place. - // Note: CACHED_FFT_LOG_SIZE is specific to the backend. - fn generate_evaluation_mappings(point: CirclePoint, log_size: u32) -> Vec { - // Mappings are the factors used to compute the evaluation twiddle. - // Every twiddle (i) is of the form (m[0])^b_0 * (m[1])^b_1 * ... * (m[log_size - - // 1])^b_log_size. - // Where (m)_j are the mappings, and b_i is the j'th bit of i. - let mut mappings = vec![point.y, point.x]; - let mut x = point.x; - for _ in 2..log_size { - x = CirclePoint::double_x(x); - mappings.push(x); - } - - // The caller function expects the mapping in natural order. i.e. (y,x,h(x),h(h(x)),...). - // If the polynomial is large, the fft does a transpose in the middle in a granularity of 16 - // (avx512). The coefficients would then be in tranposed order of 16-sized chunks. - // i.e. (a_(n-15), a_(n-14), ..., a_(n-1), a_(n-31), ..., a_(n-16), a_(n-32), ...). - // To compute the twiddles in the correct order, we need to transpose the coprresponding - // 'transposed bits' in the mappings. The result order of the mappings would then be - // (y, x, h(x), h^2(x), h^(log_n-1)(x), h^(log_n-2)(x) ...). To avoid code - // complexity for now, we just reverse the mappings, transpose, then reverse back. - // TODO(Ohad): optimize. consider changing the caller to expect the mappings in - // reversed-tranposed order. - if log_size > CACHED_FFT_LOG_SIZE { - mappings.reverse(); - let n = mappings.len(); - let n0 = (n - LOG_N_LANES as usize) / 2; - let n1 = (n - LOG_N_LANES as usize + 1) / 2; - let (ab, c) = mappings.split_at_mut(n1); - let (a, _b) = ab.split_at_mut(n0); - // Swap content of a,c. - a.swap_with_slice(&mut c[0..n0]); - mappings.reverse(); - } - - mappings - } - - // Generates twiddle steps for efficiently computing the twiddles. - // steps[i] = t_i/(t_0*t_1*...*t_i-1). - fn twiddle_steps(mappings: &[F]) -> Vec - where - F: FieldExpOps, - { - let mut denominators: Vec = vec![mappings[0]]; - - for i in 1..mappings.len() { - denominators.push(denominators[i - 1] * mappings[i]); - } - - let mut denom_inverses = vec![F::zero(); denominators.len()]; - F::batch_inverse(&denominators, &mut denom_inverses); - - let mut steps = vec![mappings[0]]; - - mappings - .iter() - .skip(1) - .zip(denom_inverses.iter()) - .for_each(|(&m, &d)| { - steps.push(m * d); - }); - steps.push(F::one()); - steps - } - - // Advances the twiddle by multiplying it by the next step. e.g: - // If idx(t) = 0b100..1010 , then f(t) = t * step[0] - // If idx(t) = 0b100..0111 , then f(t) = t * step[3] - fn advance_twiddle(twiddle: F, steps: &[F], curr_idx: usize) -> F { - twiddle * steps[curr_idx.trailing_ones() as usize] - } -} - -// TODO(spapini): Everything is returned in redundant representation, where values can also be P. -// Decide if and when it's ok and what to do if it's not. -impl PolyOps for SimdBackend { - // The twiddles type is i32, and not BaseField. This is because the fast AVX mul implementation - // requries one of the numbers to be shifted left by 1 bit. This is not a reduced - // representation of the field. - type Twiddles = Vec; - - fn new_canonical_ordered( - coset: CanonicCoset, - values: Col, - ) -> CircleEvaluation { - // TODO(spapini): Optimize. - let eval = CpuBackend::new_canonical_ordered(coset, values.into_cpu_vec()); - CircleEvaluation::new( - eval.domain, - Col::::from_iter(eval.values), - ) - } - - fn interpolate( - eval: CircleEvaluation, - twiddles: &TwiddleTree, - ) -> CirclePoly { - let mut values = eval.values; - let log_size = values.length.ilog2(); - - let twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles); - - // Safe because [PackedBaseField] is aligned on 64 bytes. - unsafe { - ifft::ifft( - transmute(values.data.as_mut_ptr()), - &twiddles, - log_size as usize, - ); - } - - // TODO(spapini): Fuse this multiplication / rotation. - let inv = PackedBaseField::broadcast(BaseField::from(eval.domain.size()).inverse()); - values.data.iter_mut().for_each(|x| *x *= inv); - - CirclePoly::new(values) - } - - fn eval_at_point(poly: &CirclePoly, point: CirclePoint) -> SecureField { - // If the polynomial is small, fallback to evaluate directly. - // TODO(Ohad): it's possible to avoid falling back. Consider fixing. - if poly.log_size() <= 8 { - return slow_eval_at_point(poly, point); - } - - let mappings = Self::generate_evaluation_mappings(point, poly.log_size()); - - // 8 lowest mappings produce the first 2^8 twiddles. Separate to optimize each calculation. - let (map_low, map_high) = mappings.split_at(4); - let twiddle_lows = - PackedSecureField::from_array(std::array::from_fn(|i| Self::twiddle_at(map_low, i))); - let (map_mid, map_high) = map_high.split_at(4); - let twiddle_mids = - PackedSecureField::from_array(std::array::from_fn(|i| Self::twiddle_at(map_mid, i))); - - // Compute the high twiddle steps. - let twiddle_steps = Self::twiddle_steps(map_high); - - // Every twiddle is a product of mappings that correspond to '1's in the bit representation - // of the current index. For every 2^n alligned chunk of 2^n elements, the twiddle - // array is the same, denoted twiddle_low. Use this to compute sums of (coeff * - // twiddle_high) mod 2^n, then multiply by twiddle_low, and sum to get the final result. - let mut sum = PackedSecureField::zeroed(); - let mut twiddle_high = SecureField::one(); - for (i, coeff_chunk) in poly.coeffs.data.array_chunks::().enumerate() { - // For every chunk of 2 ^ 4 * 2 ^ 4 = 2 ^ 8 elements, the twiddle high is the same. - // Multiply it by every mid twiddle factor to get the factors for the current chunk. - let high_twiddle_factors = - (PackedSecureField::broadcast(twiddle_high) * twiddle_mids).to_array(); - - // Sum the coefficients multiplied by each corrseponsing twiddle. Result is effectivley - // an array[16] where the value at index 'i' is the sum of all coefficients at indices - // that are i mod 16. - for (&packed_coeffs, mid_twiddle) in zip(coeff_chunk, high_twiddle_factors) { - sum += PackedSecureField::broadcast(mid_twiddle) * packed_coeffs; - } - - // Advance twiddle high. - twiddle_high = Self::advance_twiddle(twiddle_high, &twiddle_steps, i); - } - - (sum * twiddle_lows).pointwise_sum() - } - - fn extend(poly: &CirclePoly, log_size: u32) -> CirclePoly { - // TODO(spapini): Optimize or get rid of extend. - poly.evaluate(CanonicCoset::new(log_size).circle_domain()) - .interpolate() - } - - fn evaluate( - poly: &CirclePoly, - domain: CircleDomain, - twiddles: &TwiddleTree, - ) -> CircleEvaluation { - // TODO(spapini): Precompute twiddles. - // TODO(spapini): Handle small cases. - let log_size = domain.log_size(); - let fft_log_size = poly.log_size(); - assert!( - log_size >= fft_log_size, - "Can only evaluate on larger domains" - ); - - let twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles); - - // Evaluate on a big domains by evaluating on several subdomains. - let log_subdomains = log_size - fft_log_size; - - // Allocate the destination buffer without initializing. - let mut values = Vec::with_capacity(domain.size() >> LOG_N_LANES); - #[allow(clippy::uninit_vec)] - unsafe { - values.set_len(domain.size() >> LOG_N_LANES) - }; - - for i in 0..(1 << log_subdomains) { - // The subdomain twiddles are a slice of the large domain twiddles. - let subdomain_twiddles = (0..(fft_log_size - 1)) - .map(|layer_i| { - &twiddles[layer_i as usize] - [i << (fft_log_size - 2 - layer_i)..(i + 1) << (fft_log_size - 2 - layer_i)] - }) - .collect::>(); - - // FFT from the coefficients buffer to the values chunk. - unsafe { - rfft::fft( - transmute(poly.coeffs.data.as_ptr()), - transmute( - values[i << (fft_log_size - LOG_N_LANES) - ..(i + 1) << (fft_log_size - LOG_N_LANES)] - .as_mut_ptr(), - ), - &subdomain_twiddles, - fft_log_size as usize, - ); - } - } - - CircleEvaluation::new( - domain, - BaseColumn { - data: values, - length: domain.size(), - }, - ) - } - - fn precompute_twiddles(coset: Coset) -> TwiddleTree { - let mut twiddles = Vec::with_capacity(coset.size()); - let mut itwiddles = Vec::with_capacity(coset.size()); - - // TODO(spapini): Optimize. - for layer in &rfft::get_twiddle_dbls(coset) { - twiddles.extend(layer); - } - // Pad by any value, to make the size a power of 2. - twiddles.push(1); - assert_eq!(twiddles.len(), coset.size()); - for layer in &ifft::get_itwiddle_dbls(coset) { - itwiddles.extend(layer); - } - // Pad by any value, to make the size a power of 2. - itwiddles.push(1); - assert_eq!(itwiddles.len(), coset.size()); - - TwiddleTree { - root_coset: coset, - twiddles, - itwiddles, - } - } -} - -fn slow_eval_at_point( - poly: &CirclePoly, - point: CirclePoint, -) -> SecureField { - let mut mappings = vec![point.y, point.x]; - let mut x = point.x; - for _ in 2..poly.log_size() { - x = CirclePoint::double_x(x); - mappings.push(x); - } - mappings.reverse(); - - // If the polynomial is large, the fft does a transpose in the middle. - if poly.log_size() > CACHED_FFT_LOG_SIZE { - let n = mappings.len(); - let n0 = (n - LOG_N_LANES as usize) / 2; - let n1 = (n - LOG_N_LANES as usize + 1) / 2; - let (ab, c) = mappings.split_at_mut(n1); - let (a, _b) = ab.split_at_mut(n0); - // Swap content of a,c. - a.swap_with_slice(&mut c[0..n0]); - } - fold(cast_slice::<_, BaseField>(&poly.coeffs.data), &mappings) -} - -#[cfg(test)] -mod tests { - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use crate::core::backend::simd::circle::slow_eval_at_point; - use crate::core::backend::simd::fft::{CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE}; - use crate::core::backend::simd::SimdBackend; - use crate::core::backend::Column; - use crate::core::circle::CirclePoint; - use crate::core::fields::m31::BaseField; - use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, PolyOps}; - use crate::core::poly::{BitReversedOrder, NaturalOrder}; - - #[test] - fn test_interpolate_and_eval() { - for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 4 { - let domain = CanonicCoset::new(log_size).circle_domain(); - let evaluation = CircleEvaluation::::new( - domain, - (0..1 << log_size).map(BaseField::from).collect(), - ); - - let poly = evaluation.clone().interpolate(); - let evaluation2 = poly.evaluate(domain); - - assert_eq!(evaluation.values.to_cpu(), evaluation2.values.to_cpu()); - } - } - - #[test] - fn test_eval_extension() { - for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 2 { - let domain = CanonicCoset::new(log_size).circle_domain(); - let domain_ext = CanonicCoset::new(log_size + 2).circle_domain(); - let evaluation = CircleEvaluation::::new( - domain, - (0..1 << log_size).map(BaseField::from).collect(), - ); - let poly = evaluation.clone().interpolate(); - - let evaluation2 = poly.evaluate(domain_ext); - - assert_eq!( - poly.extend(log_size + 2).coeffs.to_cpu(), - evaluation2.interpolate().coeffs.to_cpu() - ); - } - } - - #[test] - fn test_eval_at_point() { - for log_size in MIN_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 4 { - let domain = CanonicCoset::new(log_size).circle_domain(); - let evaluation = CircleEvaluation::::new( - domain, - (0..1 << log_size).map(BaseField::from).collect(), - ); - let poly = evaluation.bit_reverse().interpolate(); - for i in [0, 1, 3, 1 << (log_size - 1), 1 << (log_size - 2)] { - let p = domain.at(i); - - let eval = poly.eval_at_point(p.into_ef()); - - assert_eq!( - eval, - BaseField::from(i).into(), - "log_size={log_size}, i={i}" - ); - } - } - } - - #[test] - fn test_circle_poly_extend() { - for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 2 { - let poly = - CirclePoly::::new((0..1 << log_size).map(BaseField::from).collect()); - let eval0 = poly.evaluate(CanonicCoset::new(log_size + 2).circle_domain()); - - let eval1 = poly - .extend(log_size + 2) - .evaluate(CanonicCoset::new(log_size + 2).circle_domain()); - - assert_eq!(eval0.values.to_cpu(), eval1.values.to_cpu()); - } - } - - #[test] - fn test_eval_securefield() { - let mut rng = SmallRng::seed_from_u64(0); - for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 2 { - let domain = CanonicCoset::new(log_size).circle_domain(); - let evaluation = CircleEvaluation::::new( - domain, - (0..1 << log_size).map(BaseField::from).collect(), - ); - let poly = evaluation.bit_reverse().interpolate(); - let x = rng.gen(); - let y = rng.gen(); - let p = CirclePoint { x, y }; - - let eval = PolyOps::eval_at_point(&poly, p); - - assert_eq!(eval, slow_eval_at_point(&poly, p), "log_size = {log_size}"); - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/cm31.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/cm31.rs deleted file mode 100644 index 31aba0a..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/cm31.rs +++ /dev/null @@ -1,230 +0,0 @@ -use std::array; -use std::ops::{Add, Mul, MulAssign, Neg, Sub}; - -use bytemuck::{Pod, Zeroable}; -use num_traits::{One, Zero}; - -use super::m31::{PackedM31, N_LANES}; -use crate::core::fields::cm31::CM31; -use crate::core::fields::FieldExpOps; - -/// SIMD implementation of [`CM31`]. -#[derive(Copy, Clone, Debug)] -pub struct PackedCM31(pub [PackedM31; 2]); - -impl PackedCM31 { - /// Constructs a new instance with all vector elements set to `value`. - pub fn broadcast(value: CM31) -> Self { - Self([PackedM31::broadcast(value.0), PackedM31::broadcast(value.1)]) - } - - /// Returns all `a` values such that each vector element is represented as `a + bi`. - pub fn a(&self) -> PackedM31 { - self.0[0] - } - - /// Returns all `b` values such that each vector element is represented as `a + bi`. - pub fn b(&self) -> PackedM31 { - self.0[1] - } - - pub fn to_array(&self) -> [CM31; N_LANES] { - let a = self.a().to_array(); - let b = self.b().to_array(); - array::from_fn(|i| CM31(a[i], b[i])) - } - - pub fn from_array(values: [CM31; N_LANES]) -> Self { - Self([ - PackedM31::from_array(values.map(|v| v.0)), - PackedM31::from_array(values.map(|v| v.1)), - ]) - } - - /// Interleaves two vectors. - pub fn interleave(self, other: Self) -> (Self, Self) { - let Self([a_evens, b_evens]) = self; - let Self([a_odds, b_odds]) = other; - let (a_lhs, a_rhs) = a_evens.interleave(a_odds); - let (b_lhs, b_rhs) = b_evens.interleave(b_odds); - (Self([a_lhs, b_lhs]), Self([a_rhs, b_rhs])) - } - - /// Deinterleaves two vectors. - pub fn deinterleave(self, other: Self) -> (Self, Self) { - let Self([a_self, b_self]) = self; - let Self([a_other, b_other]) = other; - let (a_evens, a_odds) = a_self.deinterleave(a_other); - let (b_evens, b_odds) = b_self.deinterleave(b_other); - (Self([a_evens, b_evens]), Self([a_odds, b_odds])) - } - - /// Doubles each element in the vector. - pub fn double(self) -> Self { - let Self([a, b]) = self; - Self([a.double(), b.double()]) - } -} - -impl Add for PackedCM31 { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - Self([self.a() + rhs.a(), self.b() + rhs.b()]) - } -} - -impl Sub for PackedCM31 { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - Self([self.a() - rhs.a(), self.b() - rhs.b()]) - } -} - -impl Mul for PackedCM31 { - type Output = Self; - - fn mul(self, rhs: Self) -> Self::Output { - // Compute using Karatsuba. - let ac = self.a() * rhs.a(); - let bd = self.b() * rhs.b(); - // Computes (a + b) * (c + d). - let ab_t_cd = (self.a() + self.b()) * (rhs.a() + rhs.b()); - // (ac - bd) + (ad + bc)i. - Self([ac - bd, ab_t_cd - ac - bd]) - } -} - -impl Zero for PackedCM31 { - fn zero() -> Self { - Self([PackedM31::zero(), PackedM31::zero()]) - } - - fn is_zero(&self) -> bool { - self.a().is_zero() && self.b().is_zero() - } -} - -unsafe impl Pod for PackedCM31 {} - -unsafe impl Zeroable for PackedCM31 { - fn zeroed() -> Self { - unsafe { core::mem::zeroed() } - } -} - -impl One for PackedCM31 { - fn one() -> Self { - Self([PackedM31::one(), PackedM31::zero()]) - } -} - -impl MulAssign for PackedCM31 { - fn mul_assign(&mut self, rhs: Self) { - *self = *self * rhs; - } -} - -impl FieldExpOps for PackedCM31 { - fn inverse(&self) -> Self { - assert!(!self.is_zero(), "0 has no inverse"); - // 1 / (a + bi) = (a - bi) / (a^2 + b^2). - Self([self.a(), -self.b()]) * (self.a().square() + self.b().square()).inverse() - } -} - -impl Add for PackedCM31 { - type Output = Self; - - fn add(self, rhs: PackedM31) -> Self::Output { - Self([self.a() + rhs, self.b()]) - } -} - -impl Sub for PackedCM31 { - type Output = Self; - - fn sub(self, rhs: PackedM31) -> Self::Output { - let Self([a, b]) = self; - Self([a - rhs, b]) - } -} - -impl Mul for PackedCM31 { - type Output = Self; - - fn mul(self, rhs: PackedM31) -> Self::Output { - let Self([a, b]) = self; - Self([a * rhs, b * rhs]) - } -} - -impl Neg for PackedCM31 { - type Output = Self; - - fn neg(self) -> Self::Output { - let Self([a, b]) = self; - Self([-a, -b]) - } -} - -#[cfg(test)] -mod tests { - use std::array; - - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use crate::core::backend::simd::cm31::PackedCM31; - - #[test] - fn addition_works() { - let mut rng = SmallRng::seed_from_u64(0); - let lhs = rng.gen(); - let rhs = rng.gen(); - let packed_lhs = PackedCM31::from_array(lhs); - let packed_rhs = PackedCM31::from_array(rhs); - - let res = packed_lhs + packed_rhs; - - assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i])); - } - - #[test] - fn subtraction_works() { - let mut rng = SmallRng::seed_from_u64(0); - let lhs = rng.gen(); - let rhs = rng.gen(); - let packed_lhs = PackedCM31::from_array(lhs); - let packed_rhs = PackedCM31::from_array(rhs); - - let res = packed_lhs - packed_rhs; - - assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i])); - } - - #[test] - fn multiplication_works() { - let mut rng = SmallRng::seed_from_u64(0); - let lhs = rng.gen(); - let rhs = rng.gen(); - let packed_lhs = PackedCM31::from_array(lhs); - let packed_rhs = PackedCM31::from_array(rhs); - - let res = packed_lhs * packed_rhs; - - assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i])); - } - - #[test] - fn negation_works() { - let mut rng = SmallRng::seed_from_u64(0); - let values = rng.gen(); - let packed_values = PackedCM31::from_array(values); - - let res = -packed_values; - - assert_eq!(res.to_array(), values.map(|v| -v)); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/column.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/column.rs deleted file mode 100644 index 6486940..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/column.rs +++ /dev/null @@ -1,656 +0,0 @@ -use std::iter::zip; -use std::{array, mem}; - -use bytemuck::allocation::cast_vec; -use bytemuck::{cast_slice, cast_slice_mut, Zeroable}; -use itertools::{izip, Itertools}; -use num_traits::Zero; - -use super::cm31::PackedCM31; -use super::m31::{PackedBaseField, N_LANES}; -use super::qm31::{PackedQM31, PackedSecureField}; -use super::very_packed_m31::{VeryPackedBaseField, VeryPackedSecureField, N_VERY_PACKED_ELEMS}; -use super::SimdBackend; -use crate::core::backend::{Column, CpuBackend}; -use crate::core::fields::cm31::CM31; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE}; -use crate::core::fields::{FieldExpOps, FieldOps}; - -impl FieldOps for SimdBackend { - fn batch_inverse(column: &BaseColumn, dst: &mut BaseColumn) { - PackedBaseField::batch_inverse(&column.data, &mut dst.data); - } -} - -impl FieldOps for SimdBackend { - fn batch_inverse(column: &SecureColumn, dst: &mut SecureColumn) { - PackedSecureField::batch_inverse(&column.data, &mut dst.data); - } -} - -/// An efficient structure for storing and operating on a arbitrary number of [`BaseField`] values. -#[derive(Clone, Debug)] -pub struct BaseColumn { - pub data: Vec, - /// The number of [`BaseField`]s in the vector. - pub length: usize, -} - -impl BaseColumn { - /// Extracts a slice containing the entire vector of [`BaseField`]s. - pub fn as_slice(&self) -> &[BaseField] { - &cast_slice(&self.data)[..self.length] - } - - /// Extracts a mutable slice containing the entire vector of [`BaseField`]s. - pub fn as_mut_slice(&mut self) -> &mut [BaseField] { - &mut cast_slice_mut(&mut self.data)[..self.length] - } - - pub fn into_cpu_vec(mut self) -> Vec { - let capacity = self.data.capacity() * N_LANES; - let length = self.length; - let ptr = self.data.as_mut_ptr() as *mut BaseField; - let res = unsafe { Vec::from_raw_parts(ptr, length, capacity) }; - mem::forget(self); - res - } - - /// Returns a vector of `BaseColumnMutSlice`s, each mutably owning - /// `chunk_size` `PackedBaseField`s (i.e, `chuck_size` * `N_LANES` elements). - pub fn chunks_mut(&mut self, chunk_size: usize) -> Vec> { - self.data - .chunks_mut(chunk_size) - .map(BaseColumnMutSlice) - .collect_vec() - } - - pub fn into_secure_column(self) -> SecureColumn { - let length = self.len(); - let data = self.data.into_iter().map(PackedSecureField::from).collect(); - SecureColumn { data, length } - } -} - -impl Column for BaseColumn { - fn zeros(length: usize) -> Self { - let data = vec![PackedBaseField::zeroed(); length.div_ceil(N_LANES)]; - Self { data, length } - } - - #[allow(clippy::uninit_vec)] - unsafe fn uninitialized(length: usize) -> Self { - let mut data = Vec::with_capacity(length.div_ceil(N_LANES)); - data.set_len(length.div_ceil(N_LANES)); - Self { data, length } - } - - fn to_cpu(&self) -> Vec { - self.as_slice().to_vec() - } - - fn len(&self) -> usize { - self.length - } - - fn at(&self, index: usize) -> BaseField { - self.data[index / N_LANES].to_array()[index % N_LANES] - } - - fn set(&mut self, index: usize, value: BaseField) { - let mut packed = self.data[index / N_LANES].to_array(); - packed[index % N_LANES] = value; - self.data[index / N_LANES] = PackedBaseField::from_array(packed) - } -} - -impl FromIterator for BaseColumn { - fn from_iter>(iter: I) -> Self { - let mut chunks = iter.into_iter().array_chunks(); - let mut data = (&mut chunks).map(PackedBaseField::from_array).collect_vec(); - let mut length = data.len() * N_LANES; - - if let Some(remainder) = chunks.into_remainder() { - if !remainder.is_empty() { - length += remainder.len(); - let mut last = [BaseField::zero(); N_LANES]; - last[..remainder.len()].copy_from_slice(remainder.as_slice()); - data.push(PackedBaseField::from_array(last)); - } - } - - Self { data, length } - } -} - -// A efficient structure for storing and operating on a arbitrary number of [`SecureField`] values. -#[derive(Clone, Debug)] -pub struct CM31Column { - pub data: Vec, - pub length: usize, -} - -impl Column for CM31Column { - fn zeros(length: usize) -> Self { - Self { - data: vec![PackedCM31::zeroed(); length.div_ceil(N_LANES)], - length, - } - } - - #[allow(clippy::uninit_vec)] - unsafe fn uninitialized(length: usize) -> Self { - let mut data = Vec::with_capacity(length.div_ceil(N_LANES)); - data.set_len(length.div_ceil(N_LANES)); - Self { data, length } - } - - fn to_cpu(&self) -> Vec { - self.data - .iter() - .flat_map(|x| x.to_array()) - .take(self.length) - .collect() - } - - fn len(&self) -> usize { - self.length - } - - fn at(&self, index: usize) -> CM31 { - self.data[index / N_LANES].to_array()[index % N_LANES] - } - - fn set(&mut self, index: usize, value: CM31) { - let mut packed = self.data[index / N_LANES].to_array(); - packed[index % N_LANES] = value; - self.data[index / N_LANES] = PackedCM31::from_array(packed) - } -} - -impl FromIterator for CM31Column { - fn from_iter>(iter: I) -> Self { - let mut chunks = iter.into_iter().array_chunks(); - let mut data = (&mut chunks).map(PackedCM31::from_array).collect_vec(); - let mut length = data.len() * N_LANES; - - if let Some(remainder) = chunks.into_remainder() { - if !remainder.is_empty() { - length += remainder.len(); - let mut last = [CM31::zero(); N_LANES]; - last[..remainder.len()].copy_from_slice(remainder.as_slice()); - data.push(PackedCM31::from_array(last)); - } - } - - Self { data, length } - } -} - -impl FromIterator for CM31Column { - fn from_iter>(iter: I) -> Self { - let data = (&mut iter.into_iter()).collect_vec(); - let length = data.len() * N_LANES; - - Self { data, length } - } -} - -/// A mutable slice of a BaseColumn. -pub struct BaseColumnMutSlice<'a>(pub &'a mut [PackedBaseField]); - -impl<'a> BaseColumnMutSlice<'a> { - pub fn at(&self, index: usize) -> BaseField { - self.0[index / N_LANES].to_array()[index % N_LANES] - } - - pub fn set(&mut self, index: usize, value: BaseField) { - let mut packed = self.0[index / N_LANES].to_array(); - packed[index % N_LANES] = value; - self.0[index / N_LANES] = PackedBaseField::from_array(packed) - } -} - -/// An efficient structure for storing and operating on a arbitrary number of [`SecureField`] -/// values. -#[derive(Clone, Debug)] -pub struct SecureColumn { - pub data: Vec, - /// The number of [`SecureField`]s in the vector. - pub length: usize, -} - -impl SecureColumn { - // Separates a single column of `PackedSecureField` elements into `SECURE_EXTENSION_DEGREE` many - // `PackedBaseField` coordinate columns. - pub fn into_secure_column_by_coords(self) -> SecureColumnByCoords { - if self.len() < N_LANES { - return self.to_cpu().into_iter().collect(); - } - - let length = self.length; - let packed_length = self.data.len(); - let mut columns = array::from_fn(|_| Vec::with_capacity(packed_length)); - - for v in self.data { - let packed_coords = v.into_packed_m31s(); - zip(&mut columns, packed_coords).for_each(|(col, packed_coord)| col.push(packed_coord)); - } - - SecureColumnByCoords { - columns: columns.map(|col| BaseColumn { data: col, length }), - } - } -} - -impl Column for SecureColumn { - fn zeros(length: usize) -> Self { - Self { - data: vec![PackedSecureField::zeroed(); length.div_ceil(N_LANES)], - length, - } - } - - #[allow(clippy::uninit_vec)] - unsafe fn uninitialized(length: usize) -> Self { - let mut data = Vec::with_capacity(length.div_ceil(N_LANES)); - data.set_len(length.div_ceil(N_LANES)); - Self { data, length } - } - - fn to_cpu(&self) -> Vec { - self.data - .iter() - .flat_map(|x| x.to_array()) - .take(self.length) - .collect() - } - - fn len(&self) -> usize { - self.length - } - - fn at(&self, index: usize) -> SecureField { - self.data[index / N_LANES].to_array()[index % N_LANES] - } - - fn set(&mut self, index: usize, value: SecureField) { - let mut packed = self.data[index / N_LANES].to_array(); - packed[index % N_LANES] = value; - self.data[index / N_LANES] = PackedSecureField::from_array(packed) - } -} - -impl FromIterator for SecureColumn { - fn from_iter>(iter: I) -> Self { - let mut chunks = iter.into_iter().array_chunks(); - let mut data = (&mut chunks) - .map(PackedSecureField::from_array) - .collect_vec(); - let mut length = data.len() * N_LANES; - - if let Some(remainder) = chunks.into_remainder() { - if !remainder.is_empty() { - length += remainder.len(); - let mut last = [SecureField::zero(); N_LANES]; - last[..remainder.len()].copy_from_slice(remainder.as_slice()); - data.push(PackedSecureField::from_array(last)); - } - } - - Self { data, length } - } -} - -impl FromIterator for SecureColumn { - fn from_iter>(iter: I) -> Self { - let data = iter.into_iter().collect_vec(); - let length = data.len() * N_LANES; - Self { data, length } - } -} - -/// A mutable slice of a SecureColumnByCoords. -pub struct SecureColumnByCoordsMutSlice<'a>(pub [BaseColumnMutSlice<'a>; SECURE_EXTENSION_DEGREE]); - -impl<'a> SecureColumnByCoordsMutSlice<'a> { - /// # Safety - /// - /// `vec_index` must be a valid index. - pub unsafe fn packed_at(&self, vec_index: usize) -> PackedSecureField { - PackedQM31([ - PackedCM31([ - *self.0[0].0.get_unchecked(vec_index), - *self.0[1].0.get_unchecked(vec_index), - ]), - PackedCM31([ - *self.0[2].0.get_unchecked(vec_index), - *self.0[3].0.get_unchecked(vec_index), - ]), - ]) - } - - /// # Safety - /// - /// `vec_index` must be a valid index. - pub unsafe fn set_packed(&mut self, vec_index: usize, value: PackedSecureField) { - let PackedQM31([PackedCM31([a, b]), PackedCM31([c, d])]) = value; - *self.0[0].0.get_unchecked_mut(vec_index) = a; - *self.0[1].0.get_unchecked_mut(vec_index) = b; - *self.0[2].0.get_unchecked_mut(vec_index) = c; - *self.0[3].0.get_unchecked_mut(vec_index) = d; - } -} - -impl SecureColumnByCoords { - pub fn packed_len(&self) -> usize { - self.columns[0].data.len() - } - - /// # Safety - /// - /// `vec_index` must be a valid index. - pub unsafe fn packed_at(&self, vec_index: usize) -> PackedSecureField { - PackedQM31([ - PackedCM31([ - *self.columns[0].data.get_unchecked(vec_index), - *self.columns[1].data.get_unchecked(vec_index), - ]), - PackedCM31([ - *self.columns[2].data.get_unchecked(vec_index), - *self.columns[3].data.get_unchecked(vec_index), - ]), - ]) - } - - /// # Safety - /// - /// `vec_index` must be a valid index. - pub unsafe fn set_packed(&mut self, vec_index: usize, value: PackedSecureField) { - let PackedQM31([PackedCM31([a, b]), PackedCM31([c, d])]) = value; - *self.columns[0].data.get_unchecked_mut(vec_index) = a; - *self.columns[1].data.get_unchecked_mut(vec_index) = b; - *self.columns[2].data.get_unchecked_mut(vec_index) = c; - *self.columns[3].data.get_unchecked_mut(vec_index) = d; - } - - pub fn to_vec(&self) -> Vec { - izip!( - self.columns[0].to_cpu(), - self.columns[1].to_cpu(), - self.columns[2].to_cpu(), - self.columns[3].to_cpu(), - ) - .map(|(a, b, c, d)| SecureField::from_m31_array([a, b, c, d])) - .collect() - } - - /// Returns a vector of `SecureColumnByCoordsMutSlice`s, each mutably owning - /// `SECURE_EXTENSION_DEGREE` slices of `chunk_size` `PackedBaseField`s - /// (i.e, `chuck_size` * `N_LANES` secure field elements, by coordinates). - pub fn chunks_mut(&mut self, chunk_size: usize) -> Vec> { - let [a, b, c, d] = self - .columns - .get_many_mut([0, 1, 2, 3]) - .unwrap() - .map(|x| x.chunks_mut(chunk_size)); - izip!(a, b, c, d) - .map(|(a, b, c, d)| SecureColumnByCoordsMutSlice([a, b, c, d])) - .collect_vec() - } -} - -impl FromIterator for SecureColumnByCoords { - fn from_iter>(iter: I) -> Self { - let cpu_col = SecureColumnByCoords::::from_iter(iter); - let columns = cpu_col.columns.map(|col| col.into_iter().collect()); - SecureColumnByCoords { columns } - } -} - -#[derive(Clone, Debug)] -pub struct VeryPackedBaseColumn { - pub data: Vec, - /// The number of [`BaseField`]s in the vector. - pub length: usize, -} - -impl VeryPackedBaseColumn { - /// Transforms a `&BaseColumn` to a `&VeryPackedBaseColumn`. - /// # Safety - /// - /// The resulting pointer does not update the underlying `data`'s length. - pub unsafe fn transform_under_ref(value: &BaseColumn) -> &Self { - &*(std::ptr::addr_of!(*value) as *const VeryPackedBaseColumn) - } -} - -impl From for VeryPackedBaseColumn { - fn from(value: BaseColumn) -> Self { - Self { - data: cast_vec(value.data), - length: value.length, - } - } -} - -impl From for BaseColumn { - fn from(value: VeryPackedBaseColumn) -> Self { - Self { - data: cast_vec(value.data), - length: value.length, - } - } -} - -impl FromIterator for VeryPackedBaseColumn { - fn from_iter>(iter: I) -> Self { - BaseColumn::from_iter(iter).into() - } -} - -impl Column for VeryPackedBaseColumn { - fn zeros(length: usize) -> Self { - BaseColumn::zeros(length).into() - } - - #[allow(clippy::uninit_vec)] - unsafe fn uninitialized(length: usize) -> Self { - BaseColumn::uninitialized(length).into() - } - - fn to_cpu(&self) -> Vec { - todo!() - } - - fn len(&self) -> usize { - self.length - } - - fn at(&self, index: usize) -> BaseField { - let chunk_size = N_LANES * N_VERY_PACKED_ELEMS; - self.data[index / chunk_size].to_array()[index % chunk_size] - } - - fn set(&mut self, index: usize, value: BaseField) { - let chunk_size = N_LANES * N_VERY_PACKED_ELEMS; - let mut packed = self.data[index / chunk_size].to_array(); - packed[index % chunk_size] = value; - self.data[index / chunk_size] = VeryPackedBaseField::from_array(packed) - } -} - -#[derive(Clone, Debug)] -pub struct VeryPackedSecureColumnByCoords { - pub columns: [VeryPackedBaseColumn; SECURE_EXTENSION_DEGREE], -} - -impl From> for VeryPackedSecureColumnByCoords { - fn from(value: SecureColumnByCoords) -> Self { - Self { - columns: value - .columns - .into_iter() - .map(VeryPackedBaseColumn::from) - .collect_vec() - .try_into() - .unwrap(), - } - } -} - -impl From for SecureColumnByCoords { - fn from(value: VeryPackedSecureColumnByCoords) -> Self { - Self { - columns: value - .columns - .into_iter() - .map(BaseColumn::from) - .collect_vec() - .try_into() - .unwrap(), - } - } -} - -impl VeryPackedSecureColumnByCoords { - pub fn packed_len(&self) -> usize { - self.columns[0].data.len() - } - - /// # Safety - /// - /// `vec_index` must be a valid index. - pub unsafe fn packed_at(&self, vec_index: usize) -> VeryPackedSecureField { - VeryPackedSecureField::from_fn(|i| { - PackedQM31([ - PackedCM31([ - self.columns[0].data.get_unchecked(vec_index).0[i], - self.columns[1].data.get_unchecked(vec_index).0[i], - ]), - PackedCM31([ - self.columns[2].data.get_unchecked(vec_index).0[i], - self.columns[3].data.get_unchecked(vec_index).0[i], - ]), - ]) - }) - } - - /// # Safety - /// - /// `vec_index` must be a valid index. - pub unsafe fn set_packed(&mut self, vec_index: usize, value: VeryPackedSecureField) { - for i in 0..N_VERY_PACKED_ELEMS { - let PackedQM31([PackedCM31([a, b]), PackedCM31([c, d])]) = value.0[i]; - self.columns[0].data.get_unchecked_mut(vec_index).0[i] = a; - self.columns[1].data.get_unchecked_mut(vec_index).0[i] = b; - self.columns[2].data.get_unchecked_mut(vec_index).0[i] = c; - self.columns[3].data.get_unchecked_mut(vec_index).0[i] = d; - } - } - - pub fn to_vec(&self) -> Vec { - izip!( - self.columns[0].to_cpu(), - self.columns[1].to_cpu(), - self.columns[2].to_cpu(), - self.columns[3].to_cpu(), - ) - .map(|(a, b, c, d)| SecureField::from_m31_array([a, b, c, d])) - .collect() - } - - /// Transforms a `&mut SecureColumnByCoords` to a - /// `&mut VeryPackedSecureColumnByCoords`. - /// - /// # Safety - /// - /// The resulting pointer does not update the underlying columns' `data`'s lengths. - pub unsafe fn transform_under_mut(value: &mut SecureColumnByCoords) -> &mut Self { - &mut *(std::ptr::addr_of!(*value) as *mut VeryPackedSecureColumnByCoords) - } -} - -#[cfg(test)] -mod tests { - use std::array; - - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use super::BaseColumn; - use crate::core::backend::simd::column::SecureColumn; - use crate::core::backend::simd::m31::N_LANES; - use crate::core::backend::simd::qm31::PackedQM31; - use crate::core::backend::Column; - use crate::core::fields::m31::BaseField; - use crate::core::fields::qm31::SecureField; - use crate::core::fields::secure_column::SecureColumnByCoords; - - #[test] - fn base_field_vec_from_iter_works() { - let values: [BaseField; 30] = array::from_fn(BaseField::from); - - let res = values.into_iter().collect::(); - - assert_eq!(res.to_cpu(), values); - } - - #[test] - fn secure_field_vec_from_iter_works() { - let mut rng = SmallRng::seed_from_u64(0); - let values: [SecureField; 30] = rng.gen(); - - let res = values.into_iter().collect::(); - - assert_eq!(res.to_cpu(), values); - } - - #[test] - fn test_base_column_chunks_mut() { - let values: [BaseField; N_LANES * 7] = array::from_fn(BaseField::from); - let mut col = values.into_iter().collect::(); - - const CHUNK_SIZE: usize = 2; - let mut chunks = col.chunks_mut(CHUNK_SIZE); - chunks[2].set(19, BaseField::from(1234)); - chunks[3].set(1, BaseField::from(5678)); - - assert_eq!(col.at(2 * CHUNK_SIZE * N_LANES + 19), BaseField::from(1234)); - assert_eq!(col.at(3 * CHUNK_SIZE * N_LANES + 1), BaseField::from(5678)); - } - - #[test] - fn test_secure_column_by_coords_chunks_mut() { - const COL_PACKED_SIZE: usize = 16; - let a: [BaseField; N_LANES * COL_PACKED_SIZE] = array::from_fn(BaseField::from); - let b: [BaseField; N_LANES * COL_PACKED_SIZE] = array::from_fn(BaseField::from); - let c: [BaseField; N_LANES * COL_PACKED_SIZE] = array::from_fn(BaseField::from); - let d: [BaseField; N_LANES * COL_PACKED_SIZE] = array::from_fn(BaseField::from); - let mut col = SecureColumnByCoords { - columns: [a, b, c, d].map(|values| values.into_iter().collect::()), - }; - - let mut rng = SmallRng::seed_from_u64(0); - let rand0 = PackedQM31::from_array(rng.gen()); - let rand1 = PackedQM31::from_array(rng.gen()); - - const CHUNK_SIZE: usize = 4; - let mut chunks = col.chunks_mut(CHUNK_SIZE); - unsafe { - chunks[2].set_packed(3, rand0); - chunks[3].set_packed(1, rand1); - - assert_eq!( - col.packed_at(2 * CHUNK_SIZE + 3).to_array(), - rand0.to_array() - ); - assert_eq!( - col.packed_at(3 * CHUNK_SIZE + 1).to_array(), - rand1.to_array() - ); - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/domain.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/domain.rs deleted file mode 100644 index 2093141..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/domain.rs +++ /dev/null @@ -1,86 +0,0 @@ -use std::simd::{simd_swizzle, u32x2, Simd}; - -use super::m31::{PackedM31, LOG_N_LANES}; -use crate::core::circle::{CirclePoint, M31_CIRCLE_LOG_ORDER}; -use crate::core::fields::m31::M31; -use crate::core::poly::circle::CircleDomain; -use crate::core::utils::bit_reverse_index; - -pub struct CircleDomainBitRevIterator { - domain: CircleDomain, - i: usize, - current: CirclePoint, - flips: [CirclePoint; (M31_CIRCLE_LOG_ORDER - LOG_N_LANES) as usize], -} -impl CircleDomainBitRevIterator { - pub fn new(domain: CircleDomain) -> Self { - let log_size = domain.log_size(); - assert!(log_size >= LOG_N_LANES); - - let initial_points = std::array::from_fn(|i| domain.at(bit_reverse_index(i, log_size))); - let current = CirclePoint { - x: PackedM31::from_array(initial_points.each_ref().map(|p| p.x)), - y: PackedM31::from_array(initial_points.each_ref().map(|p| p.y)), - }; - - let mut flips = [CirclePoint::zero(); (M31_CIRCLE_LOG_ORDER - LOG_N_LANES) as usize]; - for i in 0..(log_size - LOG_N_LANES) { - // L i - // 000111000000 -> - // 000000100000 - let prev_mul = bit_reverse_index((1 << i) - 1, log_size - LOG_N_LANES); - let new_mul = bit_reverse_index(1 << i, log_size - LOG_N_LANES); - let flip = domain.half_coset.step.mul(new_mul as u128) - - domain.half_coset.step.mul(prev_mul as u128); - flips[i as usize] = flip; - } - Self { - domain, - i: 0, - current, - flips, - } - } -} -impl Iterator for CircleDomainBitRevIterator { - type Item = CirclePoint; - - fn next(&mut self) -> Option { - if self.i << LOG_N_LANES >= self.domain.size() { - return None; - } - let res = self.current; - let flip = self.flips[self.i.trailing_ones() as usize]; - let flipx = Simd::splat(flip.x.0); - let flipy = u32x2::from_array([flip.y.0, (-flip.y).0]); - let flipy = simd_swizzle!(flipy, [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]); - let flip = unsafe { - CirclePoint { - x: PackedM31::from_simd_unchecked(flipx), - y: PackedM31::from_simd_unchecked(flipy), - } - }; - self.current = self.current + flip; - self.i += 1; - Some(res) - } -} - -#[test] -fn test_circle_domain_bit_rev_iterator() { - let domain = CircleDomain::new(crate::core::circle::Coset::new( - crate::core::circle::CirclePointIndex::generator(), - 5, - )); - let mut expected = domain.iter().collect::>(); - crate::core::utils::bit_reverse(&mut expected); - let actual = CircleDomainBitRevIterator::new(domain) - .flat_map(|c| -> [_; 16] { - std::array::from_fn(|i| CirclePoint { - x: c.x.to_array()[i], - y: c.y.to_array()[i], - }) - }) - .collect::>(); - assert_eq!(actual, expected); -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/ifft.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/ifft.rs deleted file mode 100644 index eb34da4..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/ifft.rs +++ /dev/null @@ -1,712 +0,0 @@ -//! Inverse fft. - -use std::simd::{simd_swizzle, u32x16, u32x2, u32x4}; - -use itertools::Itertools; - -use super::{ - compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, -}; -use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; -use crate::core::circle::Coset; -use crate::core::fields::FieldExpOps; -use crate::core::utils::bit_reverse; - -/// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values. -/// -/// # Arguments -/// -/// - `values`: A mutable pointer to the values on which the ICFFT is to be performed. -/// - `twiddle_dbl`: A reference to the doubles of the twiddle factors. -/// - `log_n_elements`: The log of the number of elements in the `values` array. -/// -/// # Panics -/// -/// Panic if `log_n_elements` is less than [`MIN_FFT_LOG_SIZE`]. -/// -/// # Safety -/// -/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. -pub unsafe fn ifft(values: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) { - assert!(log_n_elements >= MIN_FFT_LOG_SIZE as usize); - let log_n_vecs = log_n_elements - LOG_N_LANES as usize; - if log_n_elements <= CACHED_FFT_LOG_SIZE as usize { - ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements); - return; - } - - let fft_layers_pre_transpose = log_n_vecs.div_ceil(2); - let fft_layers_post_transpose = log_n_vecs / 2; - ifft_lower_with_vecwise( - values, - &twiddle_dbl[..3 + fft_layers_pre_transpose], - log_n_elements, - fft_layers_pre_transpose + LOG_N_LANES as usize, - ); - transpose_vecs(values, log_n_vecs); - ifft_lower_without_vecwise( - values, - &twiddle_dbl[3 + fft_layers_pre_transpose..], - log_n_elements, - fft_layers_post_transpose, - ); -} - -/// Computes partial ifft on `2^log_size` M31 elements. -/// -/// # Arguments -/// -/// - `values`: Pointer to the entire value array, aligned to 64 bytes. -/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the the ifft. Layer i -/// holds `2^(log_size - 1 - i)` twiddles. -/// - `log_size`: The log of the number of number of M31 elements in the array. -/// - `fft_layers`: The number of ifft layers to apply, out of log_size. -/// -/// # Panics -/// -/// Panics if `log_size` is not at least 5. -/// -/// # Safety -/// -/// `values` must have the same alignment as [`PackedBaseField`]. -/// `fft_layers` must be at least 5. -pub unsafe fn ifft_lower_with_vecwise( - values: *mut u32, - twiddle_dbl: &[&[u32]], - log_size: usize, - fft_layers: usize, -) { - const VECWISE_FFT_BITS: usize = LOG_N_LANES as usize + 1; - assert!(log_size >= VECWISE_FFT_BITS); - - assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); - - for index_h in 0..1 << (log_size - fft_layers) { - ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h); - for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) { - match fft_layers - layer { - 1 => { - ifft1_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h); - } - 2 => { - ifft2_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h); - } - _ => { - ifft3_loop( - values, - &twiddle_dbl[(layer - 1)..], - fft_layers - layer - 3, - layer, - index_h, - ); - } - } - } - } -} - -/// Computes partial ifft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of -/// the index). -/// -/// # Arguments -/// -/// - `values`: Pointer to the entire value array, aligned to 64 bytes. -/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the the ifft. -/// - `log_size`: The log of the number of number of M31 elements in the array. -/// - `fft_layers`: The number of ifft layers to apply, out of `log_size - LOG_N_LANES`. -/// -/// # Panics -/// -/// Panics if `log_size` is not at least 4. -/// -/// # Safety -/// -/// `values` must have the same alignment as [`PackedBaseField`]. -/// `fft_layers` must be at least 4. -pub unsafe fn ifft_lower_without_vecwise( - values: *mut u32, - twiddle_dbl: &[&[u32]], - log_size: usize, - fft_layers: usize, -) { - assert!(log_size >= LOG_N_LANES as usize); - - for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) { - for layer in (0..fft_layers).step_by(3) { - let fixed_layer = layer + LOG_N_LANES as usize; - match fft_layers - layer { - 1 => { - ifft1_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h); - } - 2 => { - ifft2_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h); - } - _ => { - ifft3_loop( - values, - &twiddle_dbl[layer..], - fft_layers - layer - 3, - fixed_layer, - index_h, - ); - } - } - } - } -} - -/// Runs the first 5 ifft layers across the entire array. -/// -/// # Arguments -/// -/// - `values`: Pointer to the entire value array, aligned to 64 bytes. -/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 5 ifft layers. -/// - `high_bits`: The number of bits this loops needs to run on. -/// - `index_h`: The higher part of the index, iterated by the caller. -/// -/// # Safety -/// -/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. -pub unsafe fn ifft_vecwise_loop( - values: *mut u32, - twiddle_dbl: &[&[u32]], - loop_bits: usize, - index_h: usize, -) { - for index_l in 0..1 << loop_bits { - let index = (index_h << loop_bits) + index_l; - let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const()); - let mut val1 = PackedBaseField::load(values.add(index * 32 + 16).cast_const()); - (val0, val1) = vecwise_ibutterflies( - val0, - val1, - std::array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)), - std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)), - std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)), - ); - (val0, val1) = simd_ibutterfly( - val0, - val1, - u32x16::splat(*twiddle_dbl[3].get_unchecked(index)), - ); - val0.store(values.add(index * 32)); - val1.store(values.add(index * 32 + 16)); - } -} - -/// Runs 3 ifft layers across the entire array. -/// -/// # Arguments -/// -/// - `values`: Pointer to the entire value array, aligned to 64 bytes. -/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 3 ifft layers. -/// - `loop_bits`: The number of bits this loops needs to run on. -/// - `layer`: The layer number of the first ifft layer to apply. The layers `layer`, `layer + 1`, -/// `layer + 2` are applied. -/// - `index_h`: The higher part of the index, iterated by the caller. -/// -/// # Safety -/// -/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. -pub unsafe fn ifft3_loop( - values: *mut u32, - twiddle_dbl: &[&[u32]], - loop_bits: usize, - layer: usize, - index_h: usize, -) { - for index_l in 0..1 << loop_bits { - let index = (index_h << loop_bits) + index_l; - let offset = index << (layer + 3); - for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) { - ifft3( - values, - offset + l, - layer, - std::array::from_fn(|i| { - *twiddle_dbl[0].get_unchecked((index * 4 + i) & (twiddle_dbl[0].len() - 1)) - }), - std::array::from_fn(|i| { - *twiddle_dbl[1].get_unchecked((index * 2 + i) & (twiddle_dbl[1].len() - 1)) - }), - std::array::from_fn(|i| { - *twiddle_dbl[2].get_unchecked((index + i) & (twiddle_dbl[2].len() - 1)) - }), - ); - } - } -} - -/// Runs 2 ifft layers across the entire array. -/// -/// # Arguments -/// -/// - `values`: Pointer to the entire value array, aligned to 64 bytes. -/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 2 ifft layers. -/// - `loop_bits`: The number of bits this loops needs to run on. -/// - `layer`: The layer number of the first ifft layer to apply. The layers `layer`, `layer + 1` -/// are applied. -/// - `index`: The index, iterated by the caller. -/// -/// # Safety -/// -/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. -unsafe fn ifft2_loop(values: *mut u32, twiddle_dbl: &[&[u32]], layer: usize, index: usize) { - let offset = index << (layer + 2); - for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) { - ifft2( - values, - offset + l, - layer, - std::array::from_fn(|i| { - *twiddle_dbl[0].get_unchecked((index * 2 + i) & (twiddle_dbl[0].len() - 1)) - }), - std::array::from_fn(|i| { - *twiddle_dbl[1].get_unchecked((index + i) & (twiddle_dbl[1].len() - 1)) - }), - ); - } -} - -/// Runs 1 ifft layer across the entire array. -/// -/// # Arguments -/// -/// - `values`: Pointer to the entire value array, aligned to 64 bytes. -/// - `twiddle_dbl`: The doubles of the twiddle factors for the ifft layer. -/// - `layer`: The layer number of the ifft layer to apply. -/// - `index_h`: The higher part of the index, iterated by the caller. -/// -/// # Safety -/// -/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. -unsafe fn ifft1_loop(values: *mut u32, twiddle_dbl: &[&[u32]], layer: usize, index: usize) { - let offset = index << (layer + 1); - for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) { - ifft1( - values, - offset + l, - layer, - std::array::from_fn(|i| { - *twiddle_dbl[0].get_unchecked((index + i) & (twiddle_dbl[0].len() - 1)) - }), - ); - } -} - -/// Computes the ibutterfly operation for packed M31 elements. -/// -/// Returns `val0 + val1, t (val0 - val1)`. `val0, val1` are packed M31 elements. 16 M31 words at -/// each. Each value is assumed to be in unreduced form, [0, P] including P. `twiddle_dbl` holds 16 -/// values, each is a *double* of a twiddle factor, in unreduced form. -pub fn simd_ibutterfly( - val0: PackedBaseField, - val1: PackedBaseField, - twiddle_dbl: u32x16, -) -> (PackedBaseField, PackedBaseField) { - let r0 = val0 + val1; - let r1 = val0 - val1; - let prod = mul_twiddle(r1, twiddle_dbl); - (r0, prod) -} - -/// Runs ifft on 2 vectors of 16 M31 elements. -/// -/// This amounts to 4 butterfly layers, each with 16 butterflies. -/// Each of the vectors represents a bit reversed evaluation. -/// Each value in a vectors is in unreduced form: [0, P] including P. -/// Takes 3 twiddle arrays, one for each layer after the first, holding the double of the -/// corresponding twiddle. -/// The first layer's twiddles (lower bit of the index) are computed from the second layer's -/// twiddles. The second layer takes 8 twiddles. -/// The third layer takes 4 twiddles. -/// The fourth layer takes 2 twiddles. -pub fn vecwise_ibutterflies( - mut val0: PackedBaseField, - mut val1: PackedBaseField, - twiddle1_dbl: [u32; 8], - twiddle2_dbl: [u32; 4], - twiddle3_dbl: [u32; 2], -) -> (PackedBaseField, PackedBaseField) { - // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. - - // Each `ibutterfly` take 2 512-bit registers, and does 16 butterflies element by element. - // We need to permute the 512-bit registers to get the right order for the butterflies. - // Denote the index of the 16 M31 elements in register i as i:abcd. - // At each layer we apply the following permutation to the index: - // i:abcd => d:iabc - // This is how it looks like at each iteration. - // i:abcd - // d:iabc - // ifft on d - // c:diab - // ifft on c - // b:cdia - // ifft on b - // a:bcid - // ifft on a - // i:abcd - - let (t0, t1) = compute_first_twiddles(twiddle1_dbl.into()); - - // Apply the permutation, resulting in indexing d:iabc. - (val0, val1) = val0.deinterleave(val1); - (val0, val1) = simd_ibutterfly(val0, val1, t0); - - // Apply the permutation, resulting in indexing c:diab. - (val0, val1) = val0.deinterleave(val1); - (val0, val1) = simd_ibutterfly(val0, val1, t1); - - let t = simd_swizzle!( - u32x4::from(twiddle2_dbl), - [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] - ); - // Apply the permutation, resulting in indexing b:cdia. - (val0, val1) = val0.deinterleave(val1); - (val0, val1) = simd_ibutterfly(val0, val1, t); - - let t = simd_swizzle!( - u32x2::from(twiddle3_dbl), - [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] - ); - // Apply the permutation, resulting in indexing a:bcid. - (val0, val1) = val0.deinterleave(val1); - (val0, val1) = simd_ibutterfly(val0, val1, t); - - // Apply the permutation, resulting in indexing i:abcd. - val0.deinterleave(val1) -} - -/// Returns the line twiddles (x points) for an ifft on a coset. -pub fn get_itwiddle_dbls(mut coset: Coset) -> Vec> { - let mut res = vec![]; - for _ in 0..coset.log_size() { - res.push( - coset - .iter() - .take(coset.size() / 2) - .map(|p| p.x.inverse().0 * 2) - .collect_vec(), - ); - bit_reverse(res.last_mut().unwrap()); - coset = coset.double(); - } - - res -} - -/// Applies 3 ibutterfly layers on 8 vectors of 16 M31 elements. -/// -/// Vectorized over the 16 elements of the vectors. -/// Used for radix-8 ifft. -/// Each butterfly layer, has 3 SIMD butterflies. -/// Total of 12 SIMD butterflies. -/// -/// # Arguments -/// -/// - `values`: Pointer to the entire value array. -/// - `offset`: The offset of the first value in the array. -/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values -/// that need to be transformed. For layer i this is i - 4. -/// - `twiddles_dbl0/1/2`: The double of the twiddles for the 3 layers of ibutterflies. Each layer -/// has 4/2/1 twiddles. -/// -/// # Safety -/// -/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. -pub unsafe fn ifft3( - values: *mut u32, - offset: usize, - log_step: usize, - twiddles_dbl0: [u32; 4], - twiddles_dbl1: [u32; 2], - twiddles_dbl2: [u32; 1], -) { - // Load the 8 SIMD vectors from the array. - let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); - let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); - let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); - let mut val4 = PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const()); - let mut val5 = PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const()); - let mut val6 = PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const()); - let mut val7 = PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const()); - - // Apply the first layer of ibutterflies. - (val0, val1) = simd_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); - (val2, val3) = simd_ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1])); - (val4, val5) = simd_ibutterfly(val4, val5, u32x16::splat(twiddles_dbl0[2])); - (val6, val7) = simd_ibutterfly(val6, val7, u32x16::splat(twiddles_dbl0[3])); - - // Apply the second layer of ibutterflies. - (val0, val2) = simd_ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0])); - (val1, val3) = simd_ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0])); - (val4, val6) = simd_ibutterfly(val4, val6, u32x16::splat(twiddles_dbl1[1])); - (val5, val7) = simd_ibutterfly(val5, val7, u32x16::splat(twiddles_dbl1[1])); - - // Apply the third layer of ibutterflies. - (val0, val4) = simd_ibutterfly(val0, val4, u32x16::splat(twiddles_dbl2[0])); - (val1, val5) = simd_ibutterfly(val1, val5, u32x16::splat(twiddles_dbl2[0])); - (val2, val6) = simd_ibutterfly(val2, val6, u32x16::splat(twiddles_dbl2[0])); - (val3, val7) = simd_ibutterfly(val3, val7, u32x16::splat(twiddles_dbl2[0])); - - // Store the 8 SIMD vectors back to the array. - val0.store(values.add(offset + (0 << log_step))); - val1.store(values.add(offset + (1 << log_step))); - val2.store(values.add(offset + (2 << log_step))); - val3.store(values.add(offset + (3 << log_step))); - val4.store(values.add(offset + (4 << log_step))); - val5.store(values.add(offset + (5 << log_step))); - val6.store(values.add(offset + (6 << log_step))); - val7.store(values.add(offset + (7 << log_step))); -} - -/// Applies 2 ibutterfly layers on 4 vectors of 16 M31 elements. -/// -/// Vectorized over the 16 elements of the vectors. -/// Used for radix-4 ifft. -/// Each ibutterfly layer, has 2 SIMD butterflies. -/// Total of 4 SIMD butterflies. -/// -/// # Arguments -/// -/// - `values`: Pointer to the entire value array. -/// - `offset`: The offset of the first value in the array. -/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values -/// that need to be transformed. For layer `i` this is `i - 4`. -/// - `twiddles_dbl0/1`: The double of the twiddles for the 2 layers of ibutterflies. Each layer has -/// 2/1 twiddles. -/// -/// # Safety -/// -/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. -pub unsafe fn ifft2( - values: *mut u32, - offset: usize, - log_step: usize, - twiddles_dbl0: [u32; 2], - twiddles_dbl1: [u32; 1], -) { - // Load the 4 SIMD vectors from the array. - let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); - let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); - let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); - - // Apply the first layer of butterflies. - (val0, val1) = simd_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); - (val2, val3) = simd_ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1])); - - // Apply the second layer of butterflies. - (val0, val2) = simd_ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0])); - (val1, val3) = simd_ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0])); - - // Store the 4 SIMD vectors back to the array. - val0.store(values.add(offset + (0 << log_step))); - val1.store(values.add(offset + (1 << log_step))); - val2.store(values.add(offset + (2 << log_step))); - val3.store(values.add(offset + (3 << log_step))); -} - -/// Applies 1 ibutterfly layers on 2 vectors of 16 M31 elements. -/// -/// Vectorized over the 16 elements of the vectors. -/// -/// # Arguments -/// -/// - `values`: Pointer to the entire value array. -/// - `offset`: The offset of the first value in the array. -/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values -/// that need to be transformed. For layer `i` this is `i - 4`. -/// - `twiddles_dbl0`: The double of the twiddles for the ibutterfly layer. -/// -/// # Safety -/// -/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. -pub unsafe fn ifft1(values: *mut u32, offset: usize, log_step: usize, twiddles_dbl0: [u32; 1]) { - // Load the 2 SIMD vectors from the array. - let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); - - (val0, val1) = simd_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); - - // Store the 2 SIMD vectors back to the array. - val0.store(values.add(offset + (0 << log_step))); - val1.store(values.add(offset + (1 << log_step))); -} - -#[cfg(test)] -mod tests { - use std::mem::transmute; - use std::simd::u32x16; - - use itertools::Itertools; - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use super::{ - get_itwiddle_dbls, ifft, ifft3, ifft_lower_with_vecwise, simd_ibutterfly, - vecwise_ibutterflies, - }; - use crate::core::backend::cpu::CpuCircleEvaluation; - use crate::core::backend::simd::column::BaseColumn; - use crate::core::backend::simd::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE}; - use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; - use crate::core::backend::Column; - use crate::core::fft::ibutterfly as ground_truth_ibutterfly; - use crate::core::fields::m31::BaseField; - use crate::core::poly::circle::{CanonicCoset, CircleDomain}; - - #[test] - fn test_ibutterfly() { - let mut rng = SmallRng::seed_from_u64(0); - let mut v0: [BaseField; N_LANES] = rng.gen(); - let mut v1: [BaseField; N_LANES] = rng.gen(); - let twiddle: [BaseField; N_LANES] = rng.gen(); - let twiddle_dbl = twiddle.map(|v| v.0 * 2); - - let (r0, r1) = simd_ibutterfly(v0.into(), v1.into(), twiddle_dbl.into()); - - let r0 = r0.to_array(); - let r1 = r1.to_array(); - for i in 0..N_LANES { - ground_truth_ibutterfly(&mut v0[i], &mut v1[i], twiddle[i]); - assert_eq!((v0[i], v1[i]), (r0[i], r1[i]), "mismatch at i={i}"); - } - } - - #[test] - fn test_ifft3() { - let mut rng = SmallRng::seed_from_u64(0); - let values = rng.gen::<[BaseField; 8]>().map(PackedBaseField::broadcast); - let twiddles0: [BaseField; 4] = rng.gen(); - let twiddles1: [BaseField; 2] = rng.gen(); - let twiddles2: [BaseField; 1] = rng.gen(); - let twiddles0_dbl = twiddles0.map(|v| v.0 * 2); - let twiddles1_dbl = twiddles1.map(|v| v.0 * 2); - let twiddles2_dbl = twiddles2.map(|v| v.0 * 2); - - let mut res = values; - unsafe { - ifft3( - transmute(res.as_mut_ptr()), - 0, - LOG_N_LANES as usize, - twiddles0_dbl, - twiddles1_dbl, - twiddles2_dbl, - ) - }; - - let mut expected = values.map(|v| v.to_array()[0]); - for i in 0..8 { - let j = i ^ 1; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], expected[j]); - ground_truth_ibutterfly(&mut v0, &mut v1, twiddles0[i / 2]); - (expected[i], expected[j]) = (v0, v1); - } - for i in 0..8 { - let j = i ^ 2; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], expected[j]); - ground_truth_ibutterfly(&mut v0, &mut v1, twiddles1[i / 4]); - (expected[i], expected[j]) = (v0, v1); - } - for i in 0..8 { - let j = i ^ 4; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], expected[j]); - ground_truth_ibutterfly(&mut v0, &mut v1, twiddles2[0]); - (expected[i], expected[j]) = (v0, v1); - } - for i in 0..8 { - assert_eq!( - res[i].to_array(), - [expected[i]; N_LANES], - "mismatch at i={i}" - ); - } - } - - #[test] - fn test_vecwise_ibutterflies() { - let domain = CanonicCoset::new(5).circle_domain(); - let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); - assert_eq!(twiddle_dbls.len(), 4); - let mut rng = SmallRng::seed_from_u64(0); - let values: [[BaseField; 16]; 2] = rng.gen(); - - let res = { - let (val0, val1) = vecwise_ibutterflies( - values[0].into(), - values[1].into(), - twiddle_dbls[0].clone().try_into().unwrap(), - twiddle_dbls[1].clone().try_into().unwrap(), - twiddle_dbls[2].clone().try_into().unwrap(), - ); - let (val0, val1) = simd_ibutterfly(val0, val1, u32x16::splat(twiddle_dbls[3][0])); - [val0.to_array(), val1.to_array()].concat() - }; - - assert_eq!(res, ground_truth_ifft(domain, values.flatten())); - } - - #[test] - fn test_ifft_lower_with_vecwise() { - for log_size in 5..12 { - let domain = CanonicCoset::new(log_size).circle_domain(); - let mut rng = SmallRng::seed_from_u64(0); - let values = (0..domain.size()).map(|_| rng.gen()).collect_vec(); - let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); - - let mut res = values.iter().copied().collect::(); - unsafe { - ifft_lower_with_vecwise( - transmute(res.data.as_mut_ptr()), - &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), - log_size as usize, - log_size as usize, - ); - } - - assert_eq!(res.to_cpu(), ground_truth_ifft(domain, &values)); - } - } - - #[test] - fn test_ifft_full() { - for log_size in CACHED_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 3 { - let domain = CanonicCoset::new(log_size).circle_domain(); - let mut rng = SmallRng::seed_from_u64(0); - let values = (0..domain.size()).map(|_| rng.gen()).collect_vec(); - let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); - - let mut res = values.iter().copied().collect::(); - unsafe { - ifft( - transmute(res.data.as_mut_ptr()), - &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), - log_size as usize, - ); - transpose_vecs(transmute(res.data.as_mut_ptr()), log_size as usize - 4); - } - - assert_eq!(res.to_cpu(), ground_truth_ifft(domain, &values)); - } - } - - fn ground_truth_ifft(domain: CircleDomain, values: &[BaseField]) -> Vec { - let eval = CpuCircleEvaluation::new(domain, values.to_vec()); - let mut res = eval.interpolate().coeffs; - let denorm = BaseField::from(domain.size()); - res.iter_mut().for_each(|v| *v *= denorm); - res - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/mod.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/mod.rs deleted file mode 100644 index ca44979..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/mod.rs +++ /dev/null @@ -1,120 +0,0 @@ -use std::simd::{simd_swizzle, u32x16, u32x8}; - -use super::m31::PackedBaseField; -use crate::core::fields::m31::P; - -pub mod ifft; -pub mod rfft; - -pub const CACHED_FFT_LOG_SIZE: u32 = 16; - -pub const MIN_FFT_LOG_SIZE: u32 = 5; - -// TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce -// it somewhere. - -/// Transposes the SIMD vectors in the given array. -/// -/// Swaps the bit index abc <-> cba, where |a|=|c| and |b| = 0 or 1, according to the parity of -/// `log_n_vecs`. -/// When log_n_vecs is odd, transforms the index abc <-> cba, w -/// -/// # Arguments -/// -/// - `values`: A mutable pointer to the values that are to be transposed. -/// - `log_n_vecs`: The log of the number of SIMD vectors in the `values` array. -/// -/// # Safety -/// -/// Behavior is undefined if `values` does not have the same alignment as [`u32x16`]. -pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) { - let half = log_n_vecs / 2; - for b in 0..1 << (log_n_vecs & 1) { - for a in 0..1 << half { - for c in 0..1 << half { - let i = (a << (log_n_vecs - half)) | (b << half) | c; - let j = (c << (log_n_vecs - half)) | (b << half) | a; - if i >= j { - continue; - } - let val0 = load(values.add(i << 4).cast_const()); - let val1 = load(values.add(j << 4).cast_const()); - store(values.add(i << 4), val1); - store(values.add(j << 4), val0); - } - } - } -} - -/// Computes the twiddles for the first fft layer from the second, and loads both to SIMD registers. -/// -/// Returns the twiddles for the first layer and the twiddles for the second layer. -pub fn compute_first_twiddles(twiddle1_dbl: u32x8) -> (u32x16, u32x16) { - // Start by loading the twiddles for the second layer (layer 1): - let t1 = simd_swizzle!( - twiddle1_dbl, - twiddle1_dbl, - [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7] - ); - - // The twiddles for layer 0 can be computed from the twiddles for layer 1. - // Since the twiddles are bit reversed, we consider the circle domain in bit reversed order. - // Each consecutive 4 points in the bit reversed order of a coset form a circle coset of size 4. - // A circle coset of size 4 in bit reversed order looks like this: - // [(x, y), (-x, -y), (y, -x), (-y, x)] - // Note: This is related to the choice of M31_CIRCLE_GEN, and the fact the a quarter rotation - // is (0,-1) and not (0,1). (0,1) would yield another relation. - // The twiddles for layer 0 are the y coordinates: - // [y, -y, -x, x] - // The twiddles for layer 1 in bit reversed order are the x coordinates: - // [x, y] - // Works also for inverse of the twiddles. - - // The twiddles for layer 0 are computed like this: - // t0[4i:4i+3] = [t1[2i+1], -t1[2i+1], -t1[2i], t1[2i]] - // Xoring a double twiddle with P*2 transforms it to the double of it negation. - // Note that this keeps the values as a double of a value in the range [0, P]. - const P2: u32 = P * 2; - const NEGATION_MASK: u32x16 = - u32x16::from_array([0, P2, P2, 0, 0, P2, P2, 0, 0, P2, P2, 0, 0, P2, P2, 0]); - let t0 = simd_swizzle!( - t1, - [ - 0b0001, 0b0001, 0b0000, 0b0000, 0b0011, 0b0011, 0b0010, 0b0010, 0b0101, 0b0101, 0b0100, - 0b0100, 0b0111, 0b0111, 0b0110, 0b0110, - ] - ) ^ NEGATION_MASK; - (t0, t1) -} - -#[inline] -unsafe fn load(mem_addr: *const u32) -> u32x16 { - std::ptr::read(mem_addr as *const u32x16) -} - -#[inline] -unsafe fn store(mem_addr: *mut u32, a: u32x16) { - std::ptr::write(mem_addr as *mut u32x16, a); -} - -/// Computes `v * twiddle` -fn mul_twiddle(v: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { - // TODO: Come up with a better approach than `cfg`ing on target_feature. - // TODO: Ensure all these branches get tested in the CI. - cfg_if::cfg_if! { - if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { - // TODO: For architectures that when multiplying require doubling then the twiddles - // should be precomputed as double. For other architectures, the twiddle should be - // precomputed without doubling. - crate::core::backend::simd::m31::_mul_doubled_neon(v, twiddle_dbl) - } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { - crate::core::backend::simd::m31::_mul_doubled_wasm(v, twiddle_dbl) - } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { - crate::core::backend::simd::m31::_mul_doubled_avx512(v, twiddle_dbl) - } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { - crate::core::backend::simd::m31::_mul_doubled_avx2(v, twiddle_dbl) - } else { - crate::core::backend::simd::m31::_mul_doubled_simd(v, twiddle_dbl) - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/rfft.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/rfft.rs deleted file mode 100644 index 6d51fd0..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/rfft.rs +++ /dev/null @@ -1,742 +0,0 @@ -//! Regular (forward) fft. - -use std::array; -use std::simd::{simd_swizzle, u32x16, u32x2, u32x4, u32x8}; - -use itertools::Itertools; - -use super::{ - compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, -}; -use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; -use crate::core::circle::Coset; -use crate::core::utils::bit_reverse; - -/// Performs a Circle Fast Fourier Transform (CFFT) on the given values. -/// -/// # Arguments -/// -/// * `src`: A pointer to the values to transform. -/// * `dst`: A pointer to the destination array. -/// * `twiddle_dbl`: A reference to the doubles of the twiddle factors. -/// * `log_n_elements`: The log of the number of elements in the `values` array. -/// -/// # Panics -/// -/// This function will panic if `log_n_elements` is less than `MIN_FFT_LOG_SIZE`. -/// -/// # Safety -/// -/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. -pub unsafe fn fft(src: *const u32, dst: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) { - assert!(log_n_elements >= MIN_FFT_LOG_SIZE as usize); - let log_n_vecs = log_n_elements - LOG_N_LANES as usize; - if log_n_elements <= CACHED_FFT_LOG_SIZE as usize { - fft_lower_with_vecwise(src, dst, twiddle_dbl, log_n_elements, log_n_elements); - return; - } - - let fft_layers_pre_transpose = log_n_vecs.div_ceil(2); - let fft_layers_post_transpose = log_n_vecs / 2; - fft_lower_without_vecwise( - src, - dst, - &twiddle_dbl[(3 + fft_layers_pre_transpose)..], - log_n_elements, - fft_layers_post_transpose, - ); - transpose_vecs(dst, log_n_vecs); - fft_lower_with_vecwise( - dst, - dst, - &twiddle_dbl[..3 + fft_layers_pre_transpose], - log_n_elements, - fft_layers_pre_transpose + LOG_N_LANES as usize, - ); -} - -/// Computes partial fft on `2^log_size` M31 elements. -/// -/// # Arguments -/// -/// - `src`: A pointer to the values to transform, aligned to 64 bytes. -/// - `dst`: A pointer to the destination array, aligned to 64 bytes. -/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the the fft. Layer `i` -/// holds `2^(log_size - 1 - i)` twiddles. -/// - `log_size`: The log of the number of number of M31 elements in the array. -/// - `fft_layers`: The number of fft layers to apply, out of log_size. -/// -/// # Panics -/// -/// Panics if `log_size` is not at least 5. -/// -/// # Safety -/// -/// `src` and `dst` must have same alignment as [`PackedBaseField`]. -/// `fft_layers` must be at least 5. -pub unsafe fn fft_lower_with_vecwise( - src: *const u32, - dst: *mut u32, - twiddle_dbl: &[&[u32]], - log_size: usize, - fft_layers: usize, -) { - const VECWISE_FFT_BITS: usize = LOG_N_LANES as usize + 1; - assert!(log_size >= VECWISE_FFT_BITS); - - assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); - - for index_h in 0..1 << (log_size - fft_layers) { - let mut src = src; - for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3).rev() { - match fft_layers - layer { - 1 => { - fft1_loop(src, dst, &twiddle_dbl[(layer - 1)..], layer, index_h); - } - 2 => { - fft2_loop(src, dst, &twiddle_dbl[(layer - 1)..], layer, index_h); - } - _ => { - fft3_loop( - src, - dst, - &twiddle_dbl[(layer - 1)..], - fft_layers - layer - 3, - layer, - index_h, - ); - } - } - src = dst; - } - fft_vecwise_loop( - src, - dst, - twiddle_dbl, - fft_layers - VECWISE_FFT_BITS, - index_h, - ); - } -} - -/// Computes partial fft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of -/// the index). -/// -/// # Arguments -/// -/// - `src`: A pointer to the values to transform, aligned to 64 bytes. -/// - `dst`: A pointer to the destination array, aligned to 64 bytes. -/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the the fft. -/// - `log_size`: The log of the number of number of M31 elements in the array. -/// - `fft_layers`: The number of fft layers to apply, out of log_size - VEC_LOG_SIZE. -/// -/// # Panics -/// -/// Panics if `log_size` is not at least 4. -/// -/// # Safety -/// -/// `src` and `dst` must have same alignment as [`PackedBaseField`]. -/// `fft_layers` must be at least 4. -pub unsafe fn fft_lower_without_vecwise( - src: *const u32, - dst: *mut u32, - twiddle_dbl: &[&[u32]], - log_size: usize, - fft_layers: usize, -) { - assert!(log_size >= LOG_N_LANES as usize); - - for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) { - let mut src = src; - for layer in (0..fft_layers).step_by(3).rev() { - let fixed_layer = layer + LOG_N_LANES as usize; - match fft_layers - layer { - 1 => { - fft1_loop(src, dst, &twiddle_dbl[layer..], fixed_layer, index_h); - } - 2 => { - fft2_loop(src, dst, &twiddle_dbl[layer..], fixed_layer, index_h); - } - _ => { - fft3_loop( - src, - dst, - &twiddle_dbl[layer..], - fft_layers - layer - 3, - fixed_layer, - index_h, - ); - } - } - src = dst; - } - } -} - -/// Runs the last 5 fft layers across the entire array. -/// -/// # Arguments -/// -/// - `src`: A pointer to the values to transform, aligned to 64 bytes. -/// - `dst`: A pointer to the destination array, aligned to 64 bytes. -/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 5 fft layers. -/// - `high_bits`: The number of bits this loops needs to run on. -/// - `index_h`: The higher part of the index, iterated by the caller. -/// -/// # Safety -/// -/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. -unsafe fn fft_vecwise_loop( - src: *const u32, - dst: *mut u32, - twiddle_dbl: &[&[u32]], - loop_bits: usize, - index_h: usize, -) { - for index_l in 0..1 << loop_bits { - let index = (index_h << loop_bits) + index_l; - let mut val0 = PackedBaseField::load(src.add(index * 32)); - let mut val1 = PackedBaseField::load(src.add(index * 32 + 16)); - (val0, val1) = simd_butterfly( - val0, - val1, - u32x16::splat(*twiddle_dbl[3].get_unchecked(index)), - ); - (val0, val1) = vecwise_butterflies( - val0, - val1, - array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)), - array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)), - array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)), - ); - val0.store(dst.add(index * 32)); - val1.store(dst.add(index * 32 + 16)); - } -} - -/// Runs 3 fft layers across the entire array. -/// -/// # Arguments -/// -/// - `src`: A pointer to the values to transform, aligned to 64 bytes. -/// - `dst`: A pointer to the destination array, aligned to 64 bytes. -/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 3 fft layers. -/// - `loop_bits`: The number of bits this loops needs to run on. -/// - `layer`: The layer number of the first fft layer to apply. The layers `layer`, `layer + 1`, -/// `layer + 2` are applied. -/// - `index_h`: The higher part of the index, iterated by the caller. -/// -/// # Safety -/// -/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. -unsafe fn fft3_loop( - src: *const u32, - dst: *mut u32, - twiddle_dbl: &[&[u32]], - loop_bits: usize, - layer: usize, - index_h: usize, -) { - for index_l in 0..1 << loop_bits { - let index = (index_h << loop_bits) + index_l; - let offset = index << (layer + 3); - for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) { - fft3( - src, - dst, - offset + l, - layer, - array::from_fn(|i| { - *twiddle_dbl[0].get_unchecked((index * 4 + i) & (twiddle_dbl[0].len() - 1)) - }), - array::from_fn(|i| { - *twiddle_dbl[1].get_unchecked((index * 2 + i) & (twiddle_dbl[1].len() - 1)) - }), - array::from_fn(|i| { - *twiddle_dbl[2].get_unchecked((index + i) & (twiddle_dbl[2].len() - 1)) - }), - ); - } - } -} - -/// Runs 2 fft layers across the entire array. -/// -/// # Arguments -/// -/// - `src`: A pointer to the values to transform, aligned to 64 bytes. -/// - `dst`: A pointer to the destination array, aligned to 64 bytes. -/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 2 fft layers. -/// - `loop_bits`: The number of bits this loops needs to run on. -/// - `layer`: The layer number of the first fft layer to apply. The layers `layer`, `layer + 1` are -/// applied. -/// - `index`: The index, iterated by the caller. -/// -/// # Safety -/// -/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. -unsafe fn fft2_loop( - src: *const u32, - dst: *mut u32, - twiddle_dbl: &[&[u32]], - layer: usize, - index: usize, -) { - let offset = index << (layer + 2); - for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) { - fft2( - src, - dst, - offset + l, - layer, - array::from_fn(|i| { - *twiddle_dbl[0].get_unchecked((index * 2 + i) & (twiddle_dbl[0].len() - 1)) - }), - array::from_fn(|i| { - *twiddle_dbl[1].get_unchecked((index + i) & (twiddle_dbl[1].len() - 1)) - }), - ); - } -} - -/// Runs 1 fft layer across the entire array. -/// -/// # Arguments -/// -/// - `src`: A pointer to the values to transform, aligned to 64 bytes. -/// - `dst`: A pointer to the destination array, aligned to 64 bytes. -/// - `twiddle_dbl`: The doubles of the twiddle factors for the fft layer. -/// - `layer`: The layer number of the fft layer to apply. -/// - `index_h`: The higher part of the index, iterated by the caller. -/// -/// # Safety -/// -/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. -unsafe fn fft1_loop( - src: *const u32, - dst: *mut u32, - twiddle_dbl: &[&[u32]], - layer: usize, - index: usize, -) { - let offset = index << (layer + 1); - for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) { - fft1( - src, - dst, - offset + l, - layer, - array::from_fn(|i| { - *twiddle_dbl[0].get_unchecked((index + i) & (twiddle_dbl[0].len() - 1)) - }), - ); - } -} - -/// Computes the butterfly operation for packed M31 elements. -/// -/// Returns `val0 + t val1, val0 - t val1`. `val0, val1` are packed M31 elements. 16 M31 words at -/// each. Each value is assumed to be in unreduced form, [0, P] including P. Returned values are in -/// unreduced form, [0, P] including P. twiddle_dbl holds 16 values, each is a *double* of a twiddle -/// factor, in unreduced form, [0, 2*P]. -pub fn simd_butterfly( - val0: PackedBaseField, - val1: PackedBaseField, - twiddle_dbl: u32x16, -) -> (PackedBaseField, PackedBaseField) { - let prod = mul_twiddle(val1, twiddle_dbl); - (val0 + prod, val0 - prod) -} - -/// Runs fft on 2 vectors of 16 M31 elements. -/// -/// This amounts to 4 butterfly layers, each with 16 butterflies. -/// Each of the vectors represents natural ordered polynomial coefficeint. -/// Each value in a vectors is in unreduced form: [0, P] including P. -/// Takes 4 twiddle arrays, one for each layer, holding the double of the corresponding twiddle. -/// The first layer (higher bit of the index) takes 2 twiddles. -/// The second layer takes 4 twiddles. -/// etc. -pub fn vecwise_butterflies( - mut val0: PackedBaseField, - mut val1: PackedBaseField, - twiddle1_dbl: [u32; 8], - twiddle2_dbl: [u32; 4], - twiddle3_dbl: [u32; 2], -) -> (PackedBaseField, PackedBaseField) { - // TODO(spapini): Compute twiddle0 from twiddle1. - // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. - // The implementation is the exact reverse of vecwise_ibutterflies(). - // See the comments in its body for more info. - let t = simd_swizzle!( - u32x2::from(twiddle3_dbl), - [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] - ); - (val0, val1) = val0.interleave(val1); - (val0, val1) = simd_butterfly(val0, val1, t); - - let t = simd_swizzle!( - u32x4::from(twiddle2_dbl), - [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] - ); - (val0, val1) = val0.interleave(val1); - (val0, val1) = simd_butterfly(val0, val1, t); - - let (t0, t1) = compute_first_twiddles(u32x8::from(twiddle1_dbl)); - (val0, val1) = val0.interleave(val1); - (val0, val1) = simd_butterfly(val0, val1, t1); - - (val0, val1) = val0.interleave(val1); - (val0, val1) = simd_butterfly(val0, val1, t0); - - val0.interleave(val1) -} - -/// Returns the line twiddles (x points) for an fft on a coset. -pub fn get_twiddle_dbls(mut coset: Coset) -> Vec> { - let mut res = vec![]; - for _ in 0..coset.log_size() { - res.push( - coset - .iter() - .take(coset.size() / 2) - .map(|p| p.x.0 * 2) - .collect_vec(), - ); - bit_reverse(res.last_mut().unwrap()); - coset = coset.double(); - } - - res -} - -/// Applies 3 butterfly layers on 8 vectors of 16 M31 elements. -/// -/// Vectorized over the 16 elements of the vectors. -/// Used for radix-8 ifft. -/// Each butterfly layer, has 3 SIMD butterflies. -/// Total of 12 SIMD butterflies. -/// -/// # Arguments -/// -/// - `src`: A pointer to the values to transform, aligned to 64 bytes. -/// - `dst`: A pointer to the destination array, aligned to 64 bytes. -/// - `offset`: The offset of the first value in the array. -/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values -/// that need to be transformed. For layer i this is i - 4. -/// - `twiddles_dbl0/1/2`: The double of the twiddles for the 3 layers of butterflies. Each layer -/// has 4/2/1 twiddles. -/// -/// # Safety -/// -/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. -pub unsafe fn fft3( - src: *const u32, - dst: *mut u32, - offset: usize, - log_step: usize, - twiddles_dbl0: [u32; 4], - twiddles_dbl1: [u32; 2], - twiddles_dbl2: [u32; 1], -) { - // Load the 8 SIMD vectors from the array. - let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); - let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); - let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step))); - let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step))); - let mut val4 = PackedBaseField::load(src.add(offset + (4 << log_step))); - let mut val5 = PackedBaseField::load(src.add(offset + (5 << log_step))); - let mut val6 = PackedBaseField::load(src.add(offset + (6 << log_step))); - let mut val7 = PackedBaseField::load(src.add(offset + (7 << log_step))); - - // Apply the third layer of butterflies. - (val0, val4) = simd_butterfly(val0, val4, u32x16::splat(twiddles_dbl2[0])); - (val1, val5) = simd_butterfly(val1, val5, u32x16::splat(twiddles_dbl2[0])); - (val2, val6) = simd_butterfly(val2, val6, u32x16::splat(twiddles_dbl2[0])); - (val3, val7) = simd_butterfly(val3, val7, u32x16::splat(twiddles_dbl2[0])); - - // Apply the second layer of butterflies. - (val0, val2) = simd_butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0])); - (val1, val3) = simd_butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0])); - (val4, val6) = simd_butterfly(val4, val6, u32x16::splat(twiddles_dbl1[1])); - (val5, val7) = simd_butterfly(val5, val7, u32x16::splat(twiddles_dbl1[1])); - - // Apply the first layer of butterflies. - (val0, val1) = simd_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); - (val2, val3) = simd_butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1])); - (val4, val5) = simd_butterfly(val4, val5, u32x16::splat(twiddles_dbl0[2])); - (val6, val7) = simd_butterfly(val6, val7, u32x16::splat(twiddles_dbl0[3])); - - // Store the 8 SIMD vectors back to the array. - val0.store(dst.add(offset + (0 << log_step))); - val1.store(dst.add(offset + (1 << log_step))); - val2.store(dst.add(offset + (2 << log_step))); - val3.store(dst.add(offset + (3 << log_step))); - val4.store(dst.add(offset + (4 << log_step))); - val5.store(dst.add(offset + (5 << log_step))); - val6.store(dst.add(offset + (6 << log_step))); - val7.store(dst.add(offset + (7 << log_step))); -} - -/// Applies 2 butterfly layers on 4 vectors of 16 M31 elements. -/// -/// Vectorized over the 16 elements of the vectors. -/// Used for radix-4 fft. -/// Each butterfly layer, has 2 SIMD butterflies. -/// Total of 4 SIMD butterflies. -/// -/// # Arguments -/// -/// - `src`: A pointer to the values to transform, aligned to 64 bytes. -/// - `dst`: A pointer to the destination array, aligned to 64 bytes. -/// - `offset`: The offset of the first value in the array. -/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values -/// that need to be transformed. For layer i this is i - 4. -/// - `twiddles_dbl0/1`: The double of the twiddles for the 2 layers of butterflies. Each layer has -/// 2/1 twiddles. -/// -/// # Safety -/// -/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. -pub unsafe fn fft2( - src: *const u32, - dst: *mut u32, - offset: usize, - log_step: usize, - twiddles_dbl0: [u32; 2], - twiddles_dbl1: [u32; 1], -) { - // Load the 4 SIMD vectors from the array. - let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); - let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); - let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step))); - let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step))); - - // Apply the second layer of butterflies. - (val0, val2) = simd_butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0])); - (val1, val3) = simd_butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0])); - - // Apply the first layer of butterflies. - (val0, val1) = simd_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); - (val2, val3) = simd_butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1])); - - // Store the 4 SIMD vectors back to the array. - val0.store(dst.add(offset + (0 << log_step))); - val1.store(dst.add(offset + (1 << log_step))); - val2.store(dst.add(offset + (2 << log_step))); - val3.store(dst.add(offset + (3 << log_step))); -} - -/// Applies 1 butterfly layers on 2 vectors of 16 M31 elements. -/// -/// Vectorized over the 16 elements of the vectors. -/// -/// # Arguments -/// -/// - `src`: A pointer to the values to transform, aligned to 64 bytes. -/// - `dst`: A pointer to the destination array, aligned to 64 bytes. -/// - `offset`: The offset of the first value in the array. -/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values -/// that need to be transformed. For layer i this is i - 4. -/// - `twiddles_dbl0`: The double of the twiddles for the butterfly layer. -/// -/// # Safety -/// -/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. -pub unsafe fn fft1( - src: *const u32, - dst: *mut u32, - offset: usize, - log_step: usize, - twiddles_dbl0: [u32; 1], -) { - // Load the 2 SIMD vectors from the array. - let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); - let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); - - (val0, val1) = simd_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); - - // Store the 2 SIMD vectors back to the array. - val0.store(dst.add(offset + (0 << log_step))); - val1.store(dst.add(offset + (1 << log_step))); -} - -#[cfg(test)] -mod tests { - use std::mem::transmute; - use std::simd::u32x16; - - use itertools::Itertools; - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use super::{ - fft, fft3, fft_lower_with_vecwise, get_twiddle_dbls, simd_butterfly, vecwise_butterflies, - }; - use crate::core::backend::cpu::CpuCirclePoly; - use crate::core::backend::simd::column::BaseColumn; - use crate::core::backend::simd::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE}; - use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; - use crate::core::backend::Column; - use crate::core::fft::butterfly as ground_truth_butterfly; - use crate::core::fields::m31::BaseField; - use crate::core::poly::circle::{CanonicCoset, CircleDomain}; - - #[test] - fn test_butterfly() { - let mut rng = SmallRng::seed_from_u64(0); - let mut v0: [BaseField; N_LANES] = rng.gen(); - let mut v1: [BaseField; N_LANES] = rng.gen(); - let twiddle: [BaseField; N_LANES] = rng.gen(); - let twiddle_dbl = twiddle.map(|v| v.0 * 2); - - let (r0, r1) = simd_butterfly(v0.into(), v1.into(), twiddle_dbl.into()); - - let r0 = r0.to_array(); - let r1 = r1.to_array(); - for i in 0..N_LANES { - ground_truth_butterfly(&mut v0[i], &mut v1[i], twiddle[i]); - assert_eq!((v0[i], v1[i]), (r0[i], r1[i]), "mismatch at i={i}"); - } - } - - #[test] - fn test_fft3() { - let mut rng = SmallRng::seed_from_u64(0); - let values = rng.gen::<[BaseField; 8]>().map(PackedBaseField::broadcast); - let twiddles0: [BaseField; 4] = rng.gen(); - let twiddles1: [BaseField; 2] = rng.gen(); - let twiddles2: [BaseField; 1] = rng.gen(); - let twiddles0_dbl = twiddles0.map(|v| v.0 * 2); - let twiddles1_dbl = twiddles1.map(|v| v.0 * 2); - let twiddles2_dbl = twiddles2.map(|v| v.0 * 2); - - let mut res = values; - unsafe { - fft3( - transmute(res.as_ptr()), - transmute(res.as_mut_ptr()), - 0, - LOG_N_LANES as usize, - twiddles0_dbl, - twiddles1_dbl, - twiddles2_dbl, - ) - }; - - let mut expected = values.map(|v| v.to_array()[0]); - for i in 0..8 { - let j = i ^ 4; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], expected[j]); - ground_truth_butterfly(&mut v0, &mut v1, twiddles2[0]); - (expected[i], expected[j]) = (v0, v1); - } - for i in 0..8 { - let j = i ^ 2; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], expected[j]); - ground_truth_butterfly(&mut v0, &mut v1, twiddles1[i / 4]); - (expected[i], expected[j]) = (v0, v1); - } - for i in 0..8 { - let j = i ^ 1; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], expected[j]); - ground_truth_butterfly(&mut v0, &mut v1, twiddles0[i / 2]); - (expected[i], expected[j]) = (v0, v1); - } - for i in 0..8 { - assert_eq!( - res[i].to_array(), - [expected[i]; N_LANES], - "mismatch at i={i}" - ); - } - } - - #[test] - fn test_vecwise_butterflies() { - let domain = CanonicCoset::new(5).circle_domain(); - let twiddle_dbls = get_twiddle_dbls(domain.half_coset); - assert_eq!(twiddle_dbls.len(), 4); - let mut rng = SmallRng::seed_from_u64(0); - let values: [[BaseField; 16]; 2] = rng.gen(); - - let res = { - let (val0, val1) = simd_butterfly( - values[0].into(), - values[1].into(), - u32x16::splat(twiddle_dbls[3][0]), - ); - let (val0, val1) = vecwise_butterflies( - val0, - val1, - twiddle_dbls[0].clone().try_into().unwrap(), - twiddle_dbls[1].clone().try_into().unwrap(), - twiddle_dbls[2].clone().try_into().unwrap(), - ); - [val0.to_array(), val1.to_array()].concat() - }; - - assert_eq!(res, ground_truth_fft(domain, values.flatten())); - } - - #[test] - fn test_fft_lower() { - for log_size in 5..12 { - let domain = CanonicCoset::new(log_size).circle_domain(); - let mut rng = SmallRng::seed_from_u64(0); - let values = (0..domain.size()).map(|_| rng.gen()).collect_vec(); - let twiddle_dbls = get_twiddle_dbls(domain.half_coset); - - let mut res = values.iter().copied().collect::(); - unsafe { - fft_lower_with_vecwise( - transmute(res.data.as_ptr()), - transmute(res.data.as_mut_ptr()), - &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), - log_size as usize, - log_size as usize, - ) - } - - assert_eq!(res.to_cpu(), ground_truth_fft(domain, &values)); - } - } - - #[test] - fn test_fft_full() { - for log_size in CACHED_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 3 { - let domain = CanonicCoset::new(log_size).circle_domain(); - let mut rng = SmallRng::seed_from_u64(0); - let values = (0..domain.size()).map(|_| rng.gen()).collect_vec(); - let twiddle_dbls = get_twiddle_dbls(domain.half_coset); - - let mut res = values.iter().copied().collect::(); - unsafe { - transpose_vecs(transmute(res.data.as_mut_ptr()), log_size as usize - 4); - fft( - transmute(res.data.as_ptr()), - transmute(res.data.as_mut_ptr()), - &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), - log_size as usize, - ); - } - - assert_eq!(res.to_cpu(), ground_truth_fft(domain, &values)); - } - } - - fn ground_truth_fft(domain: CircleDomain, values: &[BaseField]) -> Vec { - let poly = CpuCirclePoly::new(values.to_vec()); - poly.evaluate(domain).values - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/fri.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/fri.rs deleted file mode 100644 index 9721249..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/fri.rs +++ /dev/null @@ -1,261 +0,0 @@ -use std::array; -use std::simd::u32x8; - -use num_traits::Zero; - -use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; -use super::SimdBackend; -use crate::core::backend::simd::fft::compute_first_twiddles; -use crate::core::backend::simd::fft::ifft::simd_ibutterfly; -use crate::core::backend::simd::qm31::PackedSecureField; -use crate::core::backend::Column; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SecureColumnByCoords; -use crate::core::fri::{self, FriOps}; -use crate::core::poly::circle::SecureEvaluation; -use crate::core::poly::line::LineEvaluation; -use crate::core::poly::twiddles::TwiddleTree; -use crate::core::poly::utils::domain_line_twiddles_from_tree; -use crate::core::poly::BitReversedOrder; - -impl FriOps for SimdBackend { - fn fold_line( - eval: &LineEvaluation, - alpha: SecureField, - twiddles: &TwiddleTree, - ) -> LineEvaluation { - let log_size = eval.len().ilog2(); - if log_size <= LOG_N_LANES { - let eval = fri::fold_line(&eval.to_cpu(), alpha); - return LineEvaluation::new(eval.domain(), eval.values.into_iter().collect()); - } - - let domain = eval.domain(); - let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0]; - - let mut folded_values = SecureColumnByCoords::::zeros(1 << (log_size - 1)); - - for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) { - let value = unsafe { - let twiddle_dbl: [u32; 16] = - array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i)); - let val0 = eval.values.packed_at(vec_index * 2).into_packed_m31s(); - let val1 = eval.values.packed_at(vec_index * 2 + 1).into_packed_m31s(); - let pairs: [_; 4] = array::from_fn(|i| { - let (a, b) = val0[i].deinterleave(val1[i]); - simd_ibutterfly(a, b, std::mem::transmute(twiddle_dbl)) - }); - let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0)); - let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1)); - val0 + PackedSecureField::broadcast(alpha) * val1 - }; - unsafe { folded_values.set_packed(vec_index, value) }; - } - - LineEvaluation::new(domain.double(), folded_values) - } - - fn fold_circle_into_line( - dst: &mut LineEvaluation, - src: &SecureEvaluation, - alpha: SecureField, - twiddles: &TwiddleTree, - ) { - let log_size = src.len().ilog2(); - assert!(log_size > LOG_N_LANES, "Evaluation too small"); - - let domain = src.domain; - let alpha_sq = alpha * alpha; - let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0]; - - for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) { - let value = unsafe { - // The 16 twiddles of the circle domain can be derived from the 8 twiddles of the - // next line domain. See `compute_first_twiddles()`. - let twiddle_dbl = u32x8::from_array(array::from_fn(|i| { - *itwiddles.get_unchecked(vec_index * 8 + i) - })); - let (t0, _) = compute_first_twiddles(twiddle_dbl); - let val0 = src.values.packed_at(vec_index * 2).into_packed_m31s(); - let val1 = src.values.packed_at(vec_index * 2 + 1).into_packed_m31s(); - let pairs: [_; 4] = array::from_fn(|i| { - let (a, b) = val0[i].deinterleave(val1[i]); - simd_ibutterfly(a, b, t0) - }); - let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0)); - let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1)); - val0 + PackedSecureField::broadcast(alpha) * val1 - }; - unsafe { - dst.values.set_packed( - vec_index, - dst.values.packed_at(vec_index) * PackedSecureField::broadcast(alpha_sq) - + value, - ) - }; - } - } - - fn decompose( - eval: &SecureEvaluation, - ) -> (SecureEvaluation, SecureField) { - let lambda = decomposition_coefficient(eval); - let broadcasted_lambda = PackedSecureField::broadcast(lambda); - let mut g_values = SecureColumnByCoords::::zeros(eval.len()); - - let range = eval.len().div_ceil(N_LANES); - let half_range = range / 2; - for i in 0..half_range { - let val = unsafe { eval.packed_at(i) } - broadcasted_lambda; - unsafe { g_values.set_packed(i, val) } - } - for i in half_range..range { - let val = unsafe { eval.packed_at(i) } + broadcasted_lambda; - unsafe { g_values.set_packed(i, val) } - } - - let g = SecureEvaluation::new(eval.domain, g_values); - (g, lambda) - } -} - -/// See [`decomposition_coefficient`]. -/// -/// [`decomposition_coefficient`]: crate::core::backend::cpu::CpuBackend::decomposition_coefficient -fn decomposition_coefficient( - eval: &SecureEvaluation, -) -> SecureField { - let cols = &eval.values.columns; - let [mut x_sum, mut y_sum, mut z_sum, mut w_sum] = [PackedBaseField::zero(); 4]; - - let range = cols[0].len() / N_LANES; - let (half_a, half_b) = (range / 2, range); - - for i in 0..half_a { - x_sum += cols[0].data[i]; - y_sum += cols[1].data[i]; - z_sum += cols[2].data[i]; - w_sum += cols[3].data[i]; - } - for i in half_a..half_b { - x_sum -= cols[0].data[i]; - y_sum -= cols[1].data[i]; - z_sum -= cols[2].data[i]; - w_sum -= cols[3].data[i]; - } - - let x = x_sum.pointwise_sum(); - let y = y_sum.pointwise_sum(); - let z = z_sum.pointwise_sum(); - let w = w_sum.pointwise_sum(); - - SecureField::from_m31(x, y, z, w) / BaseField::from_u32_unchecked(1 << eval.domain.log_size()) -} - -#[cfg(test)] -mod tests { - use itertools::Itertools; - use num_traits::One; - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use crate::core::backend::simd::column::BaseColumn; - use crate::core::backend::simd::SimdBackend; - use crate::core::backend::{Column, CpuBackend}; - use crate::core::fields::m31::BaseField; - use crate::core::fields::qm31::SecureField; - use crate::core::fields::secure_column::SecureColumnByCoords; - use crate::core::fri::FriOps; - use crate::core::poly::circle::{CanonicCoset, CirclePoly, PolyOps, SecureEvaluation}; - use crate::core::poly::line::{LineDomain, LineEvaluation}; - use crate::core::poly::BitReversedOrder; - use crate::qm31; - - #[test] - fn test_fold_line() { - const LOG_SIZE: u32 = 7; - let mut rng = SmallRng::seed_from_u64(0); - let values = (0..1 << LOG_SIZE).map(|_| rng.gen()).collect_vec(); - let alpha = qm31!(1, 3, 5, 7); - let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset()); - let cpu_fold = CpuBackend::fold_line( - &LineEvaluation::new(domain, values.iter().copied().collect()), - alpha, - &CpuBackend::precompute_twiddles(domain.coset()), - ); - - let avx_fold = SimdBackend::fold_line( - &LineEvaluation::new(domain, values.iter().copied().collect()), - alpha, - &SimdBackend::precompute_twiddles(domain.coset()), - ); - - assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec()); - } - - #[test] - fn test_fold_circle_into_line() { - const LOG_SIZE: u32 = 7; - let values: Vec = (0..(1 << LOG_SIZE)) - .map(|i| qm31!(4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3)) - .collect(); - let alpha = qm31!(1, 3, 5, 7); - let circle_domain = CanonicCoset::new(LOG_SIZE).circle_domain(); - let line_domain = LineDomain::new(circle_domain.half_coset); - let mut cpu_fold = LineEvaluation::new( - line_domain, - SecureColumnByCoords::zeros(1 << (LOG_SIZE - 1)), - ); - CpuBackend::fold_circle_into_line( - &mut cpu_fold, - &SecureEvaluation::new(circle_domain, values.iter().copied().collect()), - alpha, - &CpuBackend::precompute_twiddles(line_domain.coset()), - ); - - let mut simd_fold = LineEvaluation::new( - line_domain, - SecureColumnByCoords::zeros(1 << (LOG_SIZE - 1)), - ); - SimdBackend::fold_circle_into_line( - &mut simd_fold, - &SecureEvaluation::new(circle_domain, values.iter().copied().collect()), - alpha, - &SimdBackend::precompute_twiddles(line_domain.coset()), - ); - - assert_eq!(cpu_fold.values.to_vec(), simd_fold.values.to_vec()); - } - - #[test] - fn decomposition_test() { - const DOMAIN_LOG_SIZE: u32 = 5; - const DOMAIN_LOG_HALF_SIZE: u32 = DOMAIN_LOG_SIZE - 1; - let s = CanonicCoset::new(DOMAIN_LOG_SIZE); - let domain = s.circle_domain(); - let mut coeffs = BaseColumn::zeros(1 << DOMAIN_LOG_SIZE); - // Polynomial is out of FFT space. - coeffs.as_mut_slice()[1 << DOMAIN_LOG_HALF_SIZE] = BaseField::one(); - let poly = CirclePoly::::new(coeffs); - let values = poly.evaluate(domain); - let avx_column = SecureColumnByCoords:: { - columns: [ - values.values.clone(), - values.values.clone(), - values.values.clone(), - values.values.clone(), - ], - }; - let avx_eval = SecureEvaluation::new(domain, avx_column.clone()); - let cpu_eval = - SecureEvaluation::::new(domain, avx_eval.to_cpu()); - let (cpu_g, cpu_lambda) = CpuBackend::decompose(&cpu_eval); - let (avx_g, avx_lambda) = SimdBackend::decompose(&avx_eval); - - assert_eq!(avx_lambda, cpu_lambda); - for i in 0..1 << DOMAIN_LOG_SIZE { - assert_eq!(avx_g.values.at(i), cpu_g.values.at(i)); - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/grind.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/grind.rs deleted file mode 100644 index 36721dc..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/grind.rs +++ /dev/null @@ -1,95 +0,0 @@ -use std::simd::cmp::SimdPartialOrd; -use std::simd::num::SimdUint; -use std::simd::u32x16; - -use bytemuck::cast_slice; -#[cfg(feature = "parallel")] -use rayon::prelude::*; - -use super::blake2s::compress16; -use super::SimdBackend; -use crate::core::backend::simd::m31::N_LANES; -use crate::core::channel::Blake2sChannel; -#[cfg(not(target_arch = "wasm32"))] -use crate::core::channel::{Channel, Poseidon252Channel, PoseidonBLSChannel}; -use crate::core::proof_of_work::GrindOps; - -// Note: GRIND_LOW_BITS is a cap on how much extra time we need to wait for all threads to finish. -const GRIND_LOW_BITS: u32 = 20; -const GRIND_HI_BITS: u32 = 64 - GRIND_LOW_BITS; - -impl GrindOps for SimdBackend { - fn grind(channel: &Blake2sChannel, pow_bits: u32) -> u64 { - // TODO(spapini): support more than 32 bits. - assert!(pow_bits <= 32, "pow_bits > 32 is not supported"); - let digest = channel.digest(); - let digest: &[u32] = cast_slice(&digest.0[..]); - - #[cfg(not(feature = "parallel"))] - let res = (0..=(1 << GRIND_HI_BITS)) - .find_map(|hi| grind_blake(digest, hi, pow_bits)) - .expect("Grind failed to find a solution."); - - #[cfg(feature = "parallel")] - let res = (0..=(1 << GRIND_HI_BITS)) - .into_par_iter() - .find_map_any(|hi| grind_blake(digest, hi, pow_bits)) - .expect("Grind failed to find a solution."); - - res - } -} - -fn grind_blake(digest: &[u32], hi: u64, pow_bits: u32) -> Option { - let zero: u32x16 = u32x16::default(); - let pow_bits = u32x16::splat(pow_bits); - - let state: [u32x16; 8] = std::array::from_fn(|i| u32x16::splat(digest[i])); - - let mut attempt = [zero; 16]; - attempt[0] = u32x16::splat((hi << GRIND_LOW_BITS) as u32); - attempt[0] += u32x16::from(std::array::from_fn(|i| i as u32)); - attempt[1] = u32x16::splat((hi >> (32 - GRIND_LOW_BITS)) as u32); - for low in (0..(1 << GRIND_LOW_BITS)).step_by(N_LANES) { - let res = compress16(state, attempt, zero, zero, zero, zero); - let success_mask = res[0].trailing_zeros().simd_ge(pow_bits); - if success_mask.any() { - let i = success_mask.to_array().iter().position(|&x| x).unwrap(); - return Some((hi << GRIND_LOW_BITS) + low as u64 + i as u64); - } - attempt[0] += u32x16::splat(N_LANES as u32); - } - None -} - -// TODO(spapini): This is a naive implementation. Optimize it. -#[cfg(not(target_arch = "wasm32"))] -impl GrindOps for SimdBackend { - fn grind(channel: &Poseidon252Channel, pow_bits: u32) -> u64 { - let mut nonce = 0; - loop { - let mut channel = channel.clone(); - channel.mix_nonce(nonce); - if channel.trailing_zeros() >= pow_bits { - return nonce; - } - nonce += 1; - } - } -} - -// TODO(spapini): This is a naive implementation. Optimize it. -#[cfg(not(target_arch = "wasm32"))] -impl GrindOps for SimdBackend { - fn grind(channel: &PoseidonBLSChannel, pow_bits: u32) -> u64 { - let mut nonce = 0; - loop { - let mut channel = channel.clone(); - channel.mix_nonce(nonce); - if channel.trailing_zeros() >= pow_bits { - return nonce; - } - nonce += 1; - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/gkr.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/gkr.rs deleted file mode 100644 index 017948d..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/gkr.rs +++ /dev/null @@ -1,684 +0,0 @@ -use std::iter::zip; - -use num_traits::Zero; - -use crate::core::backend::cpu::lookups::gkr::gen_eq_evals as cpu_gen_eq_evals; -use crate::core::backend::simd::column::SecureColumn; -use crate::core::backend::simd::m31::{LOG_N_LANES, N_LANES}; -use crate::core::backend::simd::qm31::PackedSecureField; -use crate::core::backend::simd::SimdBackend; -use crate::core::backend::{Column, CpuBackend}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::lookups::gkr_prover::{ - correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer, -}; -use crate::core::lookups::mle::Mle; -use crate::core::lookups::sumcheck::MultivariatePolyOracle; -use crate::core::lookups::utils::{Fraction, Reciprocal, UnivariatePoly}; - -impl GkrOps for SimdBackend { - #[allow(clippy::uninit_vec)] - fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle { - if y.len() < LOG_N_LANES as usize { - return Mle::new(cpu_gen_eq_evals(y, v).into_iter().collect()); - } - - // Start DP with CPU backend to avoid dealing with instances smaller than a SIMD vector. - let (y_last_chunk, y_rem) = y.split_last_chunk::<{ LOG_N_LANES as usize }>().unwrap(); - let initial = SecureColumn::from_iter(cpu_gen_eq_evals(y_last_chunk, v)); - assert_eq!(initial.len(), N_LANES); - - let packed_len = 1 << y_rem.len(); - let mut data = initial.data; - - data.reserve(packed_len - data.len()); - unsafe { data.set_len(packed_len) }; - - for (i, &y_j) in y_rem.iter().rev().enumerate() { - let packed_y_j = PackedSecureField::broadcast(y_j); - - let (lhs_evals, rhs_evals) = data.split_at_mut(1 << i); - - for (lhs, rhs) in zip(lhs_evals, rhs_evals) { - // Equivalent to: - // `rhs = eq(1, y_j) * lhs`, - // `lhs = eq(0, y_j) * lhs` - *rhs = *lhs * packed_y_j; - *lhs -= *rhs; - } - } - - let length = packed_len * N_LANES; - Mle::new(SecureColumn { data, length }) - } - - fn next_layer(layer: &Layer) -> Layer { - // Offload to CPU backend to avoid dealing with instances smaller than a SIMD vector. - if layer.n_variables() as u32 <= LOG_N_LANES { - return into_simd_layer(layer.to_cpu().next_layer().unwrap()); - } - - match layer { - Layer::GrandProduct(col) => next_grand_product_layer(col), - Layer::LogUpGeneric { - numerators, - denominators, - } => next_logup_generic_layer(numerators, denominators), - Layer::LogUpMultiplicities { - numerators, - denominators, - } => next_logup_multiplicities_layer(numerators, denominators), - Layer::LogUpSingles { denominators } => next_logup_singles_layer(denominators), - } - } - - fn sum_as_poly_in_first_variable( - h: &GkrMultivariatePolyOracle<'_, Self>, - claim: SecureField, - ) -> UnivariatePoly { - let n_variables = h.n_variables(); - let n_terms = 1 << n_variables.saturating_sub(1); - let eq_evals = h.eq_evals.as_ref(); - // Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube. - let y = eq_evals.y(); - - // Offload to CPU backend to avoid dealing with instances smaller than a SIMD vector. - if n_terms < N_LANES { - return h.to_cpu().sum_as_poly_in_first_variable(claim); - } - - let n_packed_terms = n_terms / N_LANES; - let packed_lambda = PackedSecureField::broadcast(h.lambda); - - let (mut eval_at_0, mut eval_at_2) = match &h.input_layer { - Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_packed_terms), - Layer::LogUpGeneric { - numerators, - denominators, - } => eval_logup_generic_sum( - eq_evals, - numerators, - denominators, - n_packed_terms, - packed_lambda, - ), - Layer::LogUpMultiplicities { - numerators, - denominators, - } => eval_logup_multiplicities_sum( - eq_evals, - numerators, - denominators, - n_packed_terms, - packed_lambda, - ), - Layer::LogUpSingles { denominators } => { - eval_logup_singles_sum(eq_evals, denominators, n_packed_terms, packed_lambda) - } - }; - - eval_at_0 *= h.eq_fixed_var_correction; - eval_at_2 *= h.eq_fixed_var_correction; - correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables) - } -} - -/// Generates the next GKR layer for Grand Product. -/// -/// Assumption: `len(layer) > N_LANES`. -fn next_grand_product_layer(layer: &Mle) -> Layer { - assert!(layer.len() > N_LANES); - let next_layer_len = layer.len() / 2; - - let data = layer - .data - .array_chunks() - .map(|&[a, b]| { - let (evens, odds) = a.deinterleave(b); - evens * odds - }) - .collect(); - - Layer::GrandProduct(Mle::new(SecureColumn { - data, - length: next_layer_len, - })) -} - -/// Generates the next GKR layer for LogUp. -/// -/// Assumption: `len(denominators) > N_LANES`. -fn next_logup_generic_layer( - numerators: &Mle, - denominators: &Mle, -) -> Layer { - assert!(denominators.len() > N_LANES); - assert_eq!(numerators.len(), denominators.len()); - - let next_layer_len = denominators.len() / 2; - let next_layer_packed_len = next_layer_len / N_LANES; - - let mut next_numerators = Vec::with_capacity(next_layer_packed_len); - let mut next_denominators = Vec::with_capacity(next_layer_packed_len); - - for i in 0..next_layer_packed_len { - let (n_even, n_odd) = numerators.data[i * 2].deinterleave(numerators.data[i * 2 + 1]); - let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]); - - let Fraction { - numerator, - denominator, - } = Fraction::new(n_even, d_even) + Fraction::new(n_odd, d_odd); - - next_numerators.push(numerator); - next_denominators.push(denominator); - } - - let next_numerators = SecureColumn { - data: next_numerators, - length: next_layer_len, - }; - - let next_denominators = SecureColumn { - data: next_denominators, - length: next_layer_len, - }; - - Layer::LogUpGeneric { - numerators: Mle::new(next_numerators), - denominators: Mle::new(next_denominators), - } -} - -/// Generates the next GKR layer for LogUp. -/// -/// Assumption: `len(denominators) > N_LANES`. -// TODO(andrew): Code duplication of `next_logup_generic_layer`. Consider unifying these. -fn next_logup_multiplicities_layer( - numerators: &Mle, - denominators: &Mle, -) -> Layer { - assert!(denominators.len() > N_LANES); - assert_eq!(numerators.len(), denominators.len()); - - let next_layer_len = denominators.len() / 2; - let next_layer_packed_len = next_layer_len / N_LANES; - - let mut next_numerators = Vec::with_capacity(next_layer_packed_len); - let mut next_denominators = Vec::with_capacity(next_layer_packed_len); - - for i in 0..next_layer_packed_len { - let (n_even, n_odd) = numerators.data[i * 2].deinterleave(numerators.data[i * 2 + 1]); - let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]); - - let Fraction { - numerator, - denominator, - } = Fraction::new(n_even, d_even) + Fraction::new(n_odd, d_odd); - - next_numerators.push(numerator); - next_denominators.push(denominator); - } - - let next_numerators = SecureColumn { - data: next_numerators, - length: next_layer_len, - }; - - let next_denominators = SecureColumn { - data: next_denominators, - length: next_layer_len, - }; - - Layer::LogUpGeneric { - numerators: Mle::new(next_numerators), - denominators: Mle::new(next_denominators), - } -} - -/// Generates the next GKR layer for LogUp. -/// -/// Assumption: `len(denominators) > N_LANES`. -fn next_logup_singles_layer(denominators: &Mle) -> Layer { - assert!(denominators.len() > N_LANES); - - let next_layer_len = denominators.len() / 2; - let next_layer_packed_len = next_layer_len / N_LANES; - - let mut next_numerators = Vec::with_capacity(next_layer_packed_len); - let mut next_denominators = Vec::with_capacity(next_layer_packed_len); - - for i in 0..next_layer_packed_len { - let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]); - - let Fraction { - numerator, - denominator, - } = Reciprocal::new(d_even) + Reciprocal::new(d_odd); - - next_numerators.push(numerator); - next_denominators.push(denominator); - } - - let next_numerators = SecureColumn { - data: next_numerators, - length: next_layer_len, - }; - - let next_denominators = SecureColumn { - data: next_denominators, - length: next_layer_len, - }; - - Layer::LogUpGeneric { - numerators: Mle::new(next_numerators), - denominators: Mle::new(next_denominators), - } -} - -/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`. -/// -/// Output of the form: `(eval_at_0, eval_at_2)`. -fn eval_grand_product_sum( - eq_evals: &EqEvals, - col: &Mle, - n_packed_terms: usize, -) -> (SecureField, SecureField) { - let mut packed_eval_at_0 = PackedSecureField::zero(); - let mut packed_eval_at_2 = PackedSecureField::zero(); - - for i in 0..n_packed_terms { - // Input polynomial at points `(r, {0, 1, 2}, bits(i), v, {0, 1})` - // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. - let (inp_at_r0iv0, inp_at_r0iv1) = col.data[i * 2].deinterleave(col.data[i * 2 + 1]); - let (inp_at_r1iv0, inp_at_r1iv1) = - col.data[(n_packed_terms + i) * 2].deinterleave(col.data[(n_packed_terms + i) * 2 + 1]); - // Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)` - // => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)` - let inp_at_r2iv0 = inp_at_r1iv0.double() - inp_at_r0iv0; - let inp_at_r2iv1 = inp_at_r1iv1.double() - inp_at_r0iv1; - - // Product polynomial `prod(x) = inp(x, 0) * inp(x, 1)` at points `(r, {0, 2}, bits(i), v)`. - // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. - let prod_at_r2iv = inp_at_r2iv0 * inp_at_r2iv1; - let prod_at_r0iv = inp_at_r0iv0 * inp_at_r0iv1; - - let eq_eval_at_0iv = eq_evals.data[i]; - packed_eval_at_0 += eq_eval_at_0iv * prod_at_r0iv; - packed_eval_at_2 += eq_eval_at_0iv * prod_at_r2iv; - } - - ( - packed_eval_at_0.pointwise_sum(), - packed_eval_at_2.pointwise_sum(), - ) -} - -fn eval_logup_generic_sum( - eq_evals: &EqEvals, - numerators: &Mle, - denominators: &Mle, - n_packed_terms: usize, - packed_lambda: PackedSecureField, -) -> (SecureField, SecureField) { - let mut packed_eval_at_0 = PackedSecureField::zero(); - let mut packed_eval_at_2 = PackedSecureField::zero(); - - let inp_numer = &numerators.data; - let inp_denom = &denominators.data; - - for i in 0..n_packed_terms { - // Input polynomials at points `(r, {0, 1, 2}, bits(i), v, {0, 1})` - // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. - let (inp_numer_at_r0iv0, inp_numer_at_r0iv1) = - inp_numer[i * 2].deinterleave(inp_numer[i * 2 + 1]); - let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) = - inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]); - let (inp_numer_at_r1iv0, inp_numer_at_r1iv1) = inp_numer[(n_packed_terms + i) * 2] - .deinterleave(inp_numer[(n_packed_terms + i) * 2 + 1]); - let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2] - .deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]); - // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` - // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` - let inp_numer_at_r2iv0 = inp_numer_at_r1iv0.double() - inp_numer_at_r0iv0; - let inp_numer_at_r2iv1 = inp_numer_at_r1iv1.double() - inp_numer_at_r0iv1; - let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0; - let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1; - - // Fraction addition polynomials: - // - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)` - // - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`. - // at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. - let Fraction { - numerator: numer_at_r0iv, - denominator: denom_at_r0iv, - } = Fraction::new(inp_numer_at_r0iv0, inp_denom_at_r0iv0) - + Fraction::new(inp_numer_at_r0iv1, inp_denom_at_r0iv1); - let Fraction { - numerator: numer_at_r2iv, - denominator: denom_at_r2iv, - } = Fraction::new(inp_numer_at_r2iv0, inp_denom_at_r2iv0) - + Fraction::new(inp_numer_at_r2iv1, inp_denom_at_r2iv1); - - let eq_eval_at_0iv = eq_evals.data[i]; - packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv); - packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv); - } - - ( - packed_eval_at_0.pointwise_sum(), - packed_eval_at_2.pointwise_sum(), - ) -} - -// TODO(andrew): Code duplication of `eval_logup_generic_sum`. Consider unifying these. -fn eval_logup_multiplicities_sum( - eq_evals: &EqEvals, - numerators: &Mle, - denominators: &Mle, - n_packed_terms: usize, - packed_lambda: PackedSecureField, -) -> (SecureField, SecureField) { - let mut packed_eval_at_0 = PackedSecureField::zero(); - let mut packed_eval_at_2 = PackedSecureField::zero(); - - let inp_numer = &numerators.data; - let inp_denom = &denominators.data; - - for i in 0..n_packed_terms { - // Input polynomials at points `(r, {0, 1, 2}, bits(i), v, {0, 1})` - // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. - let (inp_numer_at_r0iv0, inp_numer_at_r0iv1) = - inp_numer[i * 2].deinterleave(inp_numer[i * 2 + 1]); - let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) = - inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]); - let (inp_numer_at_r1iv0, inp_numer_at_r1iv1) = inp_numer[(n_packed_terms + i) * 2] - .deinterleave(inp_numer[(n_packed_terms + i) * 2 + 1]); - let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2] - .deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]); - // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` - // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` - let inp_numer_at_r2iv0 = inp_numer_at_r1iv0.double() - inp_numer_at_r0iv0; - let inp_numer_at_r2iv1 = inp_numer_at_r1iv1.double() - inp_numer_at_r0iv1; - let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0; - let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1; - - // Fraction addition polynomials: - // - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)` - // - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`. - // at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. - let Fraction { - numerator: numer_at_r0iv, - denominator: denom_at_r0iv, - } = Fraction::new(inp_numer_at_r0iv0, inp_denom_at_r0iv0) - + Fraction::new(inp_numer_at_r0iv1, inp_denom_at_r0iv1); - let Fraction { - numerator: numer_at_r2iv, - denominator: denom_at_r2iv, - } = Fraction::new(inp_numer_at_r2iv0, inp_denom_at_r2iv0) - + Fraction::new(inp_numer_at_r2iv1, inp_denom_at_r2iv1); - - let eq_eval_at_0iv = eq_evals.data[i]; - packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv); - packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv); - } - - ( - packed_eval_at_0.pointwise_sum(), - packed_eval_at_2.pointwise_sum(), - ) -} - -/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_denom(r, t, x, 1) + inp_denom(r, t, x, 0) + -/// lambda * inp_denom(r, t, x, 0) * inp_denom(r, t, x, 1))` at `t=0` and `t=2`. -/// -/// Output of the form: `(eval_at_0, eval_at_2)`. -fn eval_logup_singles_sum( - eq_evals: &EqEvals, - denominators: &Mle, - n_packed_terms: usize, - packed_lambda: PackedSecureField, -) -> (SecureField, SecureField) { - let mut packed_eval_at_0 = PackedSecureField::zero(); - let mut packed_eval_at_2 = PackedSecureField::zero(); - - let inp_denom = &denominators.data; - - for i in 0..n_packed_terms { - // Input polynomial at points `(r, {0, 1, 2}, bits(i), v, {0, 1})` - // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. - let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) = - inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]); - let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2] - .deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]); - // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` - // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` - let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0; - let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1; - - // Fraction addition polynomials: - // - `numer(x) = inp_denom(x, 1) + inp_denom(x, 0)` - // - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`. - // at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. - let Fraction { - numerator: numer_at_r0iv, - denominator: denom_at_r0iv, - } = Reciprocal::new(inp_denom_at_r0iv0) + Reciprocal::new(inp_denom_at_r0iv1); - let Fraction { - numerator: numer_at_r2iv, - denominator: denom_at_r2iv, - } = Reciprocal::new(inp_denom_at_r2iv0) + Reciprocal::new(inp_denom_at_r2iv1); - - let eq_eval_at_0iv = eq_evals.data[i]; - packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv); - packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv); - } - - ( - packed_eval_at_0.pointwise_sum(), - packed_eval_at_2.pointwise_sum(), - ) -} - -fn into_simd_layer(cpu_layer: Layer) -> Layer { - match cpu_layer { - Layer::GrandProduct(mle) => { - Layer::GrandProduct(Mle::new(mle.into_evals().into_iter().collect())) - } - Layer::LogUpGeneric { - numerators, - denominators, - } => Layer::LogUpGeneric { - numerators: Mle::new(numerators.into_evals().into_iter().collect()), - denominators: Mle::new(denominators.into_evals().into_iter().collect()), - }, - Layer::LogUpMultiplicities { - numerators, - denominators, - } => Layer::LogUpMultiplicities { - numerators: Mle::new(numerators.into_evals().into_iter().collect()), - denominators: Mle::new(denominators.into_evals().into_iter().collect()), - }, - Layer::LogUpSingles { denominators } => Layer::LogUpSingles { - denominators: Mle::new(denominators.into_evals().into_iter().collect()), - }, - } -} - -#[cfg(test)] -mod tests { - use std::iter::zip; - - use num_traits::One; - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use crate::core::backend::simd::SimdBackend; - use crate::core::backend::{Column, CpuBackend}; - use crate::core::channel::Channel; - use crate::core::fields::m31::BaseField; - use crate::core::fields::qm31::SecureField; - use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer}; - use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError}; - use crate::core::lookups::mle::Mle; - use crate::core::lookups::utils::Fraction; - use crate::core::test_utils::test_channel; - - #[test] - fn gen_eq_evals_matches_cpu() { - let two = BaseField::from(2).into(); - let y = [7, 3, 5, 6, 1, 1, 9].map(|v| BaseField::from(v).into()); - let eq_evals_cpu = CpuBackend::gen_eq_evals(&y, two); - - let eq_evals_simd = SimdBackend::gen_eq_evals(&y, two); - - assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu); - } - - #[test] - fn gen_eq_evals_with_small_assignment_matches_cpu() { - let two = BaseField::from(2).into(); - let y = [7, 3, 5].map(|v| BaseField::from(v).into()); - let eq_evals_cpu = CpuBackend::gen_eq_evals(&y, two); - - let eq_evals_simd = SimdBackend::gen_eq_evals(&y, two); - - assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu); - } - - #[test] - fn grand_product_works() -> Result<(), GkrError> { - const N: usize = 1 << 8; - let values = test_channel().draw_felts(N); - let product = values.iter().product(); - let col = Mle::::new(values.into_iter().collect()); - let input_layer = Layer::GrandProduct(col.clone()); - let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); - - let GkrArtifact { - ood_point, - claims_to_verify_by_instance, - n_variables_by_instance: _, - } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?; - - assert_eq!(proof.output_claims_by_instance, [vec![product]]); - assert_eq!( - claims_to_verify_by_instance, - [vec![col.eval_at_point(&ood_point)]] - ); - Ok(()) - } - - #[test] - fn logup_with_generic_trace_works() -> Result<(), GkrError> { - const N: usize = 1 << 8; - let mut rng = SmallRng::seed_from_u64(0); - let numerators = (0..N).map(|_| rng.gen()).collect::>(); - let denominators = (0..N).map(|_| rng.gen()).collect::>(); - let sum = zip(&numerators, &denominators) - .map(|(&n, &d)| Fraction::new(n, d)) - .sum::>(); - let numerators = Mle::::new(numerators.into_iter().collect()); - let denominators = Mle::::new(denominators.into_iter().collect()); - let input_layer = Layer::LogUpGeneric { - numerators: numerators.clone(), - denominators: denominators.clone(), - }; - let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); - - let GkrArtifact { - ood_point, - claims_to_verify_by_instance, - n_variables_by_instance: _, - } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; - - assert_eq!(claims_to_verify_by_instance.len(), 1); - assert_eq!(proof.output_claims_by_instance.len(), 1); - assert_eq!( - claims_to_verify_by_instance[0], - [ - numerators.eval_at_point(&ood_point), - denominators.eval_at_point(&ood_point) - ] - ); - assert_eq!( - proof.output_claims_by_instance[0], - [sum.numerator, sum.denominator] - ); - Ok(()) - } - - #[test] - fn logup_with_multiplicities_trace_works() -> Result<(), GkrError> { - const N: usize = 1 << 8; - let mut rng = SmallRng::seed_from_u64(0); - let numerators = (0..N).map(|_| rng.gen()).collect::>(); - let denominators = (0..N).map(|_| rng.gen()).collect::>(); - let sum = zip(&numerators, &denominators) - .map(|(&n, &d)| Fraction::new(n.into(), d)) - .sum::>(); - let numerators = Mle::::new(numerators.into_iter().collect()); - let denominators = Mle::::new(denominators.into_iter().collect()); - let input_layer = Layer::LogUpMultiplicities { - numerators: numerators.clone(), - denominators: denominators.clone(), - }; - let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); - - let GkrArtifact { - ood_point, - claims_to_verify_by_instance, - n_variables_by_instance: _, - } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; - - assert_eq!(claims_to_verify_by_instance.len(), 1); - assert_eq!(proof.output_claims_by_instance.len(), 1); - assert_eq!( - claims_to_verify_by_instance[0], - [ - numerators.eval_at_point(&ood_point), - denominators.eval_at_point(&ood_point) - ] - ); - assert_eq!( - proof.output_claims_by_instance[0], - [sum.numerator, sum.denominator] - ); - Ok(()) - } - - #[test] - fn logup_with_singles_trace_works() -> Result<(), GkrError> { - const N: usize = 1 << 8; - let mut rng = SmallRng::seed_from_u64(0); - let denominators = (0..N).map(|_| rng.gen()).collect::>(); - let sum = denominators - .iter() - .map(|&d| Fraction::new(SecureField::one(), d)) - .sum::>(); - let denominators = Mle::::new(denominators.into_iter().collect()); - let input_layer = Layer::LogUpSingles { - denominators: denominators.clone(), - }; - let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); - - let GkrArtifact { - ood_point, - claims_to_verify_by_instance, - n_variables_by_instance: _, - } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; - - assert_eq!(claims_to_verify_by_instance.len(), 1); - assert_eq!(proof.output_claims_by_instance.len(), 1); - assert_eq!( - claims_to_verify_by_instance[0], - [SecureField::one(), denominators.eval_at_point(&ood_point)] - ); - assert_eq!( - proof.output_claims_by_instance[0], - [sum.numerator, sum.denominator] - ); - Ok(()) - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/mle.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/mle.rs deleted file mode 100644 index 0e2fe73..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/mle.rs +++ /dev/null @@ -1,132 +0,0 @@ -use core::ops::Sub; -use std::iter::zip; -use std::ops::{Add, Mul}; - -use crate::core::backend::simd::column::SecureColumn; -use crate::core::backend::simd::m31::N_LANES; -use crate::core::backend::simd::qm31::PackedSecureField; -use crate::core::backend::simd::SimdBackend; -use crate::core::backend::{Column, CpuBackend}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::lookups::mle::{Mle, MleOps}; - -impl MleOps for SimdBackend { - fn fix_first_variable( - mle: Mle, - assignment: SecureField, - ) -> Mle { - let midpoint = mle.len() / 2; - - // Use CPU backend to avoid dealing with instances smaller than `PackedSecureField`. - if midpoint < N_LANES { - let cpu_mle = Mle::::new(mle.to_cpu()); - let cpu_res = cpu_mle.fix_first_variable(assignment); - return Mle::new(cpu_res.into_evals().into_iter().collect()); - } - - let packed_assignment = PackedSecureField::broadcast(assignment); - let packed_midpoint = midpoint / N_LANES; - let (evals_at_0x, evals_at_1x) = mle.data.split_at(packed_midpoint); - - let res = zip(evals_at_0x, evals_at_1x) - .enumerate() - // MLE at points `({0, 1}, rev(bits(i)), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. - .map(|(_i, (&packed_eval_at_0iv, &packed_eval_at_1iv))| { - fold_packed_mle_evals(packed_assignment, packed_eval_at_0iv, packed_eval_at_1iv) - }) - .collect(); - - Mle::new(res) - } -} - -impl MleOps for SimdBackend { - fn fix_first_variable( - mle: Mle, - assignment: SecureField, - ) -> Mle { - let midpoint = mle.len() / 2; - - // Use CPU backend to avoid dealing with instances smaller than `PackedSecureField`. - if midpoint < N_LANES { - let cpu_mle = Mle::::new(mle.to_cpu()); - let cpu_res = cpu_mle.fix_first_variable(assignment); - return Mle::new(cpu_res.into_evals().into_iter().collect()); - } - - let packed_midpoint = midpoint / N_LANES; - let packed_assignment = PackedSecureField::broadcast(assignment); - let mut packed_evals = mle.into_evals().data; - - for i in 0..packed_midpoint { - // MLE at points `({0, 1}, rev(bits(i)), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. - let packed_eval_at_0iv = packed_evals[i]; - let packed_eval_at_1iv = packed_evals[i + packed_midpoint]; - packed_evals[i] = - fold_packed_mle_evals(packed_assignment, packed_eval_at_0iv, packed_eval_at_1iv); - } - - packed_evals.truncate(packed_midpoint); - - let length = packed_evals.len() * N_LANES; - let data = packed_evals; - - Mle::new(SecureColumn { data, length }) - } -} - -/// Computes all `eq(0, assignment_i) * eval0_i + eq(1, assignment_i) * eval1_i`. -// TODO(andrew): Remove complex trait bounds once we have something like -// AbstractField/AbstractExtensionField traits. -fn fold_packed_mle_evals< - PackedF: Sub + Copy, - PackedEF: Mul + Add, ->( - assignment: PackedEF, - eval0: PackedF, - eval1: PackedF, -) -> PackedEF { - assignment * (eval1 - eval0) + eval0 -} - -#[cfg(test)] -mod tests { - use itertools::Itertools; - - use crate::core::backend::simd::SimdBackend; - use crate::core::backend::{Column, CpuBackend}; - use crate::core::channel::Channel; - use crate::core::fields::m31::BaseField; - use crate::core::fields::qm31::SecureField; - use crate::core::lookups::mle::Mle; - use crate::core::test_utils::test_channel; - - #[test] - fn fix_first_variable_with_secure_field_mle_matches_cpu() { - const N_VARIABLES: u32 = 8; - let values = test_channel().draw_felts(1 << N_VARIABLES); - let mle_simd = Mle::::new(values.iter().copied().collect()); - let mle_cpu = Mle::::new(values); - let random_assignment = SecureField::from_u32_unchecked(7, 12, 3, 2); - let mle_fixed_cpu = mle_cpu.fix_first_variable(random_assignment); - - let mle_fixed_simd = mle_simd.fix_first_variable(random_assignment); - - assert_eq!(mle_fixed_simd.into_evals().to_cpu(), *mle_fixed_cpu) - } - - #[test] - fn fix_first_variable_with_base_field_mle_matches_cpu() { - const N_VARIABLES: u32 = 8; - let values = (0..1 << N_VARIABLES).map(BaseField::from).collect_vec(); - let mle_simd = Mle::::new(values.iter().copied().collect()); - let mle_cpu = Mle::::new(values); - let random_assignment = SecureField::from_u32_unchecked(7, 12, 3, 2); - let mle_fixed_cpu = mle_cpu.fix_first_variable(random_assignment); - - let mle_fixed_simd = mle_simd.fix_first_variable(random_assignment); - - assert_eq!(mle_fixed_simd.into_evals().to_cpu(), *mle_fixed_cpu) - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/mod.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/mod.rs deleted file mode 100644 index 34395e9..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -mod gkr; -mod mle; diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/m31.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/m31.rs deleted file mode 100644 index f629162..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/m31.rs +++ /dev/null @@ -1,666 +0,0 @@ -use std::iter::Sum; -use std::mem::transmute; -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -use std::ptr; -use std::simd::cmp::SimdOrd; -use std::simd::{u32x16, Simd, Swizzle}; - -use bytemuck::{Pod, Zeroable}; -use num_traits::{One, Zero}; -use rand::distributions::{Distribution, Standard}; - -use super::qm31::PackedQM31; -use crate::core::backend::simd::utils::{InterleaveEvens, InterleaveOdds}; -use crate::core::fields::m31::{pow2147483645, BaseField, M31, P}; -use crate::core::fields::qm31::QM31; -use crate::core::fields::FieldExpOps; - -pub const LOG_N_LANES: u32 = 4; - -pub const N_LANES: usize = 1 << LOG_N_LANES; - -pub const MODULUS: Simd = Simd::from_array([P; N_LANES]); - -pub type PackedBaseField = PackedM31; - -/// Holds a vector of unreduced [`M31`] elements in the range `[0, P]`. -/// -/// Implemented with [`std::simd`] to support multiple targets (avx512, neon, wasm etc.). -// TODO: Remove `pub` visibility -#[derive(Copy, Clone, Debug)] -#[repr(transparent)] -pub struct PackedM31(Simd); - -impl PackedM31 { - /// Constructs a new instance with all vector elements set to `value`. - pub fn broadcast(M31(value): M31) -> Self { - Self(Simd::splat(value)) - } - - pub fn from_array(values: [M31; N_LANES]) -> PackedM31 { - Self(Simd::from_array(values.map(|M31(v)| v))) - } - - pub fn to_array(self) -> [M31; N_LANES] { - self.reduce().0.to_array().map(M31) - } - - /// Reduces each element of the vector to the range `[0, P)`. - fn reduce(self) -> PackedM31 { - Self(Simd::simd_min(self.0, self.0 - MODULUS)) - } - - /// Interleaves two vectors. - pub fn interleave(self, other: Self) -> (Self, Self) { - let (a, b) = self.0.interleave(other.0); - (Self(a), Self(b)) - } - - /// Deinterleaves two vectors. - pub fn deinterleave(self, other: Self) -> (Self, Self) { - let (a, b) = self.0.deinterleave(other.0); - (Self(a), Self(b)) - } - - /// Reverses the order of the elements in the vector. - pub fn reverse(self) -> Self { - Self(self.0.reverse()) - } - - /// Sums all the elements in the vector. - pub fn pointwise_sum(self) -> M31 { - self.to_array().into_iter().sum() - } - - /// Doubles each element in the vector. - pub fn double(self) -> Self { - // TODO: Make more optimal. - self + self - } - - pub fn into_simd(self) -> Simd { - self.0 - } - - /// # Safety - /// - /// Vector elements must be in the range `[0, P]`. - pub unsafe fn from_simd_unchecked(v: Simd) -> Self { - Self(v) - } - - /// # Safety - /// - /// Behavior is undefined if the pointer does not have the same alignment as - /// [`PackedM31`]. The loaded `u32` values must be in the range `[0, P]`. - pub unsafe fn load(mem_addr: *const u32) -> Self { - Self(ptr::read(mem_addr as *const u32x16)) - } - - /// # Safety - /// - /// Behavior is undefined if the pointer does not have the same alignment as - /// [`PackedM31`]. - pub unsafe fn store(self, dst: *mut u32) { - ptr::write(dst as *mut u32x16, self.0) - } -} - -impl Add for PackedM31 { - type Output = Self; - - #[inline(always)] - fn add(self, rhs: Self) -> Self::Output { - // Add word by word. Each word is in the range [0, 2P]. - let c = self.0 + rhs.0; - // Apply min(c, c-P) to each word. - // When c in [P,2P], then c-P in [0,P] which is always less than [P,2P]. - // When c in [0,P-1], then c-P in [2^32-P,2^32-1] which is always greater than [0,P-1]. - Self(Simd::simd_min(c, c - MODULUS)) - } -} - -impl AddAssign for PackedM31 { - #[inline(always)] - fn add_assign(&mut self, rhs: Self) { - *self = *self + rhs; - } -} - -impl AddAssign for PackedM31 { - #[inline(always)] - fn add_assign(&mut self, rhs: M31) { - *self = *self + PackedM31::broadcast(rhs); - } -} - -impl Mul for PackedM31 { - type Output = Self; - - #[inline(always)] - fn mul(self, rhs: Self) -> Self { - // TODO: Come up with a better approach than `cfg`ing on target_feature. - // TODO: Ensure all these branches get tested in the CI. - cfg_if::cfg_if! { - if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { - _mul_neon(self, rhs) - } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { - _mul_wasm(self, rhs) - } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { - _mul_avx512(self, rhs) - } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { - _mul_avx2(self, rhs) - } else { - _mul_simd(self, rhs) - } - } - } -} - -impl Mul for PackedM31 { - type Output = Self; - - #[inline(always)] - fn mul(self, rhs: M31) -> Self::Output { - self * PackedM31::broadcast(rhs) - } -} - -impl Add for PackedM31 { - type Output = PackedM31; - - #[inline(always)] - fn add(self, rhs: M31) -> Self::Output { - PackedM31::broadcast(rhs) + self - } -} - -impl Add for PackedM31 { - type Output = PackedQM31; - - #[inline(always)] - fn add(self, rhs: QM31) -> Self::Output { - PackedQM31::broadcast(rhs) + self - } -} - -impl Mul for PackedM31 { - type Output = PackedQM31; - - #[inline(always)] - fn mul(self, rhs: QM31) -> Self::Output { - PackedQM31::broadcast(rhs) * self - } -} - -impl MulAssign for PackedM31 { - #[inline(always)] - fn mul_assign(&mut self, rhs: Self) { - *self = *self * rhs; - } -} - -impl Neg for PackedM31 { - type Output = Self; - - #[inline(always)] - fn neg(self) -> Self::Output { - Self(MODULUS - self.0) - } -} - -impl Sub for PackedM31 { - type Output = Self; - - #[inline(always)] - fn sub(self, rhs: Self) -> Self::Output { - // Subtract word by word. Each word is in the range [-P, P]. - let c = self.0 - rhs.0; - // Apply min(c, c+P) to each word. - // When c in [0,P], then c+P in [P,2P] which is always greater than [0,P]. - // When c in [2^32-P,2^32-1], then c+P in [0,P-1] which is always less than - // [2^32-P,2^32-1]. - Self(Simd::simd_min(c + MODULUS, c)) - } -} - -impl SubAssign for PackedM31 { - #[inline(always)] - fn sub_assign(&mut self, rhs: Self) { - *self = *self - rhs; - } -} - -impl Zero for PackedM31 { - fn zero() -> Self { - Self(Simd::from_array([0; N_LANES])) - } - - fn is_zero(&self) -> bool { - self.to_array().iter().all(M31::is_zero) - } -} - -impl One for PackedM31 { - fn one() -> Self { - Self(Simd::::from_array([1; N_LANES])) - } -} - -impl FieldExpOps for PackedM31 { - fn inverse(&self) -> Self { - assert!(!self.is_zero(), "0 has no inverse"); - pow2147483645(*self) - } -} - -unsafe impl Pod for PackedM31 {} - -unsafe impl Zeroable for PackedM31 { - fn zeroed() -> Self { - unsafe { core::mem::zeroed() } - } -} - -impl From<[BaseField; N_LANES]> for PackedM31 { - fn from(v: [BaseField; N_LANES]) -> Self { - Self::from_array(v) - } -} - -impl From for PackedM31 { - fn from(v: BaseField) -> Self { - Self::broadcast(v) - } -} - -impl Distribution for Standard { - fn sample(&self, rng: &mut R) -> PackedM31 { - PackedM31::from_array(rng.gen()) - } -} - -impl Sum for PackedM31 { - fn sum>(iter: I) -> Self { - iter.fold(Self::zero(), Add::add) - } -} - -/// Returns `a * b`. -#[cfg(target_arch = "aarch64")] -pub(crate) fn _mul_neon(a: PackedM31, b: PackedM31) -> PackedM31 { - use core::arch::aarch64::{int32x2_t, vqdmull_s32}; - use std::simd::u32x4; - - let [a0, a1, a2, a3, a4, a5, a6, a7]: [int32x2_t; 8] = unsafe { transmute(a) }; - let [b0, b1, b2, b3, b4, b5, b6, b7]: [int32x2_t; 8] = unsafe { transmute(b) }; - - // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| - let c0: u32x4 = unsafe { transmute(vqdmull_s32(a0, b0)) }; - let c1: u32x4 = unsafe { transmute(vqdmull_s32(a1, b1)) }; - let c2: u32x4 = unsafe { transmute(vqdmull_s32(a2, b2)) }; - let c3: u32x4 = unsafe { transmute(vqdmull_s32(a3, b3)) }; - let c4: u32x4 = unsafe { transmute(vqdmull_s32(a4, b4)) }; - let c5: u32x4 = unsafe { transmute(vqdmull_s32(a5, b5)) }; - let c6: u32x4 = unsafe { transmute(vqdmull_s32(a6, b6)) }; - let c7: u32x4 = unsafe { transmute(vqdmull_s32(a7, b7)) }; - - // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. - // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. - let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); - let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); - let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); - let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); - - // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. - c0_c1_lo >>= 1; - c2_c3_lo >>= 1; - c4_c5_lo >>= 1; - c6_c7_lo >>= 1; - - let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; - let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; - - lo + hi -} - -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "aarch64")] -pub(crate) fn _mul_doubled_neon(a: PackedM31, b_double: u32x16) -> PackedM31 { - use core::arch::aarch64::{uint32x2_t, vmull_u32}; - use std::simd::u32x4; - - let [a0, a1, a2, a3, a4, a5, a6, a7]: [uint32x2_t; 8] = unsafe { transmute(a) }; - let [b0, b1, b2, b3, b4, b5, b6, b7]: [uint32x2_t; 8] = unsafe { transmute(b_double) }; - - // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| - let c0: u32x4 = unsafe { transmute(vmull_u32(a0, b0)) }; - let c1: u32x4 = unsafe { transmute(vmull_u32(a1, b1)) }; - let c2: u32x4 = unsafe { transmute(vmull_u32(a2, b2)) }; - let c3: u32x4 = unsafe { transmute(vmull_u32(a3, b3)) }; - let c4: u32x4 = unsafe { transmute(vmull_u32(a4, b4)) }; - let c5: u32x4 = unsafe { transmute(vmull_u32(a5, b5)) }; - let c6: u32x4 = unsafe { transmute(vmull_u32(a6, b6)) }; - let c7: u32x4 = unsafe { transmute(vmull_u32(a7, b7)) }; - - // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. - // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. - let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); - let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); - let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); - let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); - - // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. - c0_c1_lo >>= 1; - c2_c3_lo >>= 1; - c4_c5_lo >>= 1; - c6_c7_lo >>= 1; - - let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; - let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; - - lo + hi -} - -/// Returns `a * b`. -#[cfg(target_arch = "wasm32")] -pub(crate) fn _mul_wasm(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_wasm(a, b.0 + b.0) -} - -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "wasm32")] -pub(crate) fn _mul_doubled_wasm(a: PackedM31, b_double: u32x16) -> PackedM31 { - use core::arch::wasm32::{i64x2_extmul_high_u32x4, i64x2_extmul_low_u32x4, v128}; - use std::simd::u32x4; - - let [a0, a1, a2, a3]: [v128; 4] = unsafe { transmute(a) }; - let [b_double0, b_double1, b_double2, b_double3]: [v128; 4] = unsafe { transmute(b_double) }; - - let c0_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a0, b_double0)) }; - let c0_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a0, b_double0)) }; - let c1_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a1, b_double1)) }; - let c1_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a1, b_double1)) }; - let c2_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a2, b_double2)) }; - let c2_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a2, b_double2)) }; - let c3_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a3, b_double3)) }; - let c3_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a3, b_double3)) }; - - let (mut c0_even, c0_odd) = c0_lo.deinterleave(c0_hi); - let (mut c1_even, c1_odd) = c1_lo.deinterleave(c1_hi); - let (mut c2_even, c2_odd) = c2_lo.deinterleave(c2_hi); - let (mut c3_even, c3_odd) = c3_lo.deinterleave(c3_hi); - c0_even >>= 1; - c1_even >>= 1; - c2_even >>= 1; - c3_even >>= 1; - - let even: PackedM31 = unsafe { transmute([c0_even, c1_even, c2_even, c3_even]) }; - let odd: PackedM31 = unsafe { transmute([c0_odd, c1_odd, c2_odd, c3_odd]) }; - - even + odd -} - -/// Returns `a * b`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_avx512(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_avx512(a, b.0 + b.0) -} - -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_doubled_avx512(a: PackedM31, b_double: u32x16) -> PackedM31 { - use std::arch::x86_64::{__m512i, _mm512_mul_epu32, _mm512_srli_epi64}; - - let a: __m512i = unsafe { transmute(a) }; - let b_double: __m512i = unsafe { transmute(b_double) }; - - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of - // the first operand. - let a_e = a; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of - // the first operand. - let a_o = unsafe { _mm512_srli_epi64(a, 32) }; - - let b_dbl_e = b_double; - let b_dbl_o = unsafe { _mm512_srli_epi64(b_double, 32) }; - - // To compute prod = a * b start by multiplying a_e/odd by b_dbl_e/odd. - let prod_dbl_e: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_e, b_dbl_e)) }; - let prod_dbl_o: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_o, b_dbl_o)) }; - - // The result of a multiplication holds a*b in as 64-bits. - // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_dbl_e - |0|prod_e_h|prod_e_l|0| - // prod_dbl_o - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: - let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| - // Divide by 2: - prod_lo >>= 1; - // prod_lo - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: - let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_hi - |0|prod_o_h|0|prod_e_h| - - PackedM31(prod_lo) + PackedM31(prod_hi) -} - -/// Returns `a * b`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_avx2(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_avx2(a, b.0 + b.0) -} - -/// Returns `a * b`. -/// -/// `b_double` should be in the range `[0, 2P]`. -#[cfg(target_arch = "x86_64")] -pub(crate) fn _mul_doubled_avx2(a: PackedM31, b_double: u32x16) -> PackedM31 { - use std::arch::x86_64::{__m256i, _mm256_mul_epu32, _mm256_srli_epi64}; - - let [a0, a1]: [__m256i; 2] = unsafe { transmute(a) }; - let [b0_dbl, b1_dbl]: [__m256i; 2] = unsafe { transmute(b_double) }; - - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of - // the first operand. - let a0_e = a0; - let a1_e = a1; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of - // the first operand. - let a0_o = unsafe { _mm256_srli_epi64(a0, 32) }; - let a1_o = unsafe { _mm256_srli_epi64(a1, 32) }; - - let b0_dbl_e = b0_dbl; - let b1_dbl_e = b1_dbl; - let b0_dbl_o = unsafe { _mm256_srli_epi64(b0_dbl, 32) }; - let b1_dbl_o = unsafe { _mm256_srli_epi64(b1_dbl, 32) }; - - // To compute prod = a * b start by multiplying a0/1_e/odd by b0/1_e/odd. - let prod0_dbl_e = unsafe { _mm256_mul_epu32(a0_e, b0_dbl_e) }; - let prod0_dbl_o = unsafe { _mm256_mul_epu32(a0_o, b0_dbl_o) }; - let prod1_dbl_e = unsafe { _mm256_mul_epu32(a1_e, b1_dbl_e) }; - let prod1_dbl_o = unsafe { _mm256_mul_epu32(a1_o, b1_dbl_o) }; - - let prod_dbl_e: u32x16 = unsafe { transmute([prod0_dbl_e, prod1_dbl_e]) }; - let prod_dbl_o: u32x16 = unsafe { transmute([prod0_dbl_o, prod1_dbl_o]) }; - - // The result of a multiplication holds a*b in as 64-bits. - // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_dbl_e - |0|prod_e_h|prod_e_l|0| - // prod_dbl_o - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: - let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| - // Divide by 2: - prod_lo >>= 1; - // prod_lo - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: - let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); - // prod_hi - |0|prod_o_h|0|prod_e_h| - - PackedM31(prod_lo) + PackedM31(prod_hi) -} - -/// Returns `a * b`. -/// -/// Should only be used in the absence of a platform specific implementation. -pub(crate) fn _mul_simd(a: PackedM31, b: PackedM31) -> PackedM31 { - _mul_doubled_simd(a, b.0 + b.0) -} - -/// Returns `a * b`. -/// -/// Should only be used in the absence of a platform specific implementation. -/// -/// `b_double` should be in the range `[0, 2P]`. -pub(crate) fn _mul_doubled_simd(a: PackedM31, b_double: u32x16) -> PackedM31 { - const MASK_EVENS: Simd = Simd::from_array([0xFFFFFFFF; { N_LANES / 2 }]); - - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of - // the first operand. - let a_e = unsafe { transmute::<_, Simd>(a.0) & MASK_EVENS }; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of - // the first operand. - let a_o = unsafe { transmute::<_, Simd>(a) >> 32 }; - - let b_dbl_e = unsafe { transmute::<_, Simd>(b_double) & MASK_EVENS }; - let b_dbl_o = unsafe { transmute::<_, Simd>(b_double) >> 32 }; - - // To compute prod = a * b start by multiplying - // a_e/o by b_dbl_e/o. - let prod_e_dbl = a_e * b_dbl_e; - let prod_o_dbl = a_o * b_dbl_o; - - // The result of a multiplication holds a*b in as 64-bits. - // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_e_dbl - |0|prod_e_h|prod_e_l|0| - // prod_o_dbl - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: - // let prod_lows = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, - // prod_o_dbl); - // prod_ls - |prod_o_l|0|prod_e_l|0| - let mut prod_lows = InterleaveEvens::concat_swizzle( - unsafe { transmute::<_, Simd>(prod_e_dbl) }, - unsafe { transmute::<_, Simd>(prod_o_dbl) }, - ); - // Divide by 2: - prod_lows >>= 1; - // prod_ls - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: - let prod_highs = InterleaveOdds::concat_swizzle( - unsafe { transmute::<_, Simd>(prod_e_dbl) }, - unsafe { transmute::<_, Simd>(prod_o_dbl) }, - ); - - // prod_hs - |0|prod_o_h|0|prod_e_h| - PackedM31(prod_lows) + PackedM31(prod_highs) -} - -#[cfg(test)] -mod tests { - use std::array; - - use aligned::{Aligned, A64}; - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use super::PackedM31; - use crate::core::fields::m31::BaseField; - use crate::core::fields::FieldExpOps; - - #[test] - fn addition_works() { - let mut rng = SmallRng::seed_from_u64(0); - let lhs = rng.gen(); - let rhs = rng.gen(); - let packed_lhs = PackedM31::from_array(lhs); - let packed_rhs = PackedM31::from_array(rhs); - - let res = packed_lhs + packed_rhs; - - assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i])); - } - - #[test] - fn subtraction_works() { - let mut rng = SmallRng::seed_from_u64(0); - let lhs = rng.gen(); - let rhs = rng.gen(); - let packed_lhs = PackedM31::from_array(lhs); - let packed_rhs = PackedM31::from_array(rhs); - - let res = packed_lhs - packed_rhs; - - assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i])); - } - - #[test] - fn multiplication_works() { - let mut rng = SmallRng::seed_from_u64(0); - let lhs = rng.gen(); - let rhs = rng.gen(); - let packed_lhs = PackedM31::from_array(lhs); - let packed_rhs = PackedM31::from_array(rhs); - - let res = packed_lhs * packed_rhs; - - assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i])); - } - - #[test] - fn negation_works() { - let mut rng = SmallRng::seed_from_u64(0); - let values = rng.gen(); - let packed_values = PackedM31::from_array(values); - - let res = -packed_values; - - assert_eq!(res.to_array(), array::from_fn(|i| -values[i])); - } - - #[test] - fn load_works() { - let v: Aligned = Aligned(array::from_fn(|i| i as u32)); - - let res = unsafe { PackedM31::load(v.as_ptr()) }; - - assert_eq!(res.to_array().map(|v| v.0), *v); - } - - #[test] - fn store_works() { - let v = PackedM31::from_array(array::from_fn(BaseField::from)); - - let mut res: Aligned = Aligned([0; 16]); - unsafe { v.store(res.as_mut_ptr()) }; - - assert_eq!(*res, v.to_array().map(|v| v.0)); - } - - #[test] - fn inverse_works() { - let mut rng = SmallRng::seed_from_u64(0); - let values = rng.gen(); - let packed_values = PackedM31::from_array(values); - - let res = packed_values.inverse(); - - assert_eq!(res.to_array(), array::from_fn(|i| values[i].inverse())); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/mod.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/mod.rs deleted file mode 100644 index 49c7f4a..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/mod.rs +++ /dev/null @@ -1,41 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use super::{Backend, BackendForChannel}; -use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; -#[cfg(not(target_arch = "wasm32"))] -use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel; -#[cfg(not(target_arch = "wasm32"))] -use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleChannel; - -pub mod accumulation; -pub mod bit_reverse; -pub mod blake2s; -pub mod circle; -pub mod cm31; -pub mod column; -pub mod domain; -pub mod fft; -pub mod fri; -mod grind; -pub mod lookups; -pub mod m31; -#[cfg(not(target_arch = "wasm32"))] -pub mod poseidon252; -pub mod prefix_sum; -pub mod qm31; -pub mod quotients; -mod utils; -pub mod very_packed_m31; -#[cfg(not(target_arch = "wasm32"))] -pub mod poseidon_bls; - -#[derive(Copy, Clone, Debug, Deserialize, Serialize)] -pub struct SimdBackend; - -impl Backend for SimdBackend {} -impl BackendForChannel for SimdBackend {} -#[cfg(not(target_arch = "wasm32"))] -impl BackendForChannel for SimdBackend {} - -#[cfg(not(target_arch = "wasm32"))] -impl BackendForChannel for SimdBackend {} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/poseidon252.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/poseidon252.rs deleted file mode 100644 index b001481..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/poseidon252.rs +++ /dev/null @@ -1,36 +0,0 @@ -use itertools::Itertools; -use starknet_ff::FieldElement as FieldElement252; - -use super::SimdBackend; -use crate::core::backend::{Col, Column, ColumnOps}; -use crate::core::fields::m31::BaseField; -#[cfg(not(target_arch = "wasm32"))] -use crate::core::vcs::ops::MerkleHasher; -use crate::core::vcs::ops::MerkleOps; -use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleHasher; - -impl ColumnOps for SimdBackend { - type Column = Vec; - - fn bit_reverse_column(_column: &mut Self::Column) { - unimplemented!() - } -} - -impl MerkleOps for SimdBackend { - // TODO(ShaharS): replace with SIMD implementation. - fn commit_on_layer( - log_size: u32, - prev_layer: Option<&Vec>, - columns: &[&Col], - ) -> Vec { - (0..(1 << log_size)) - .map(|i| { - Poseidon252MerkleHasher::hash_node( - prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])), - &columns.iter().map(|column| column.at(i)).collect_vec(), - ) - }) - .collect() - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/poseidon_bls.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/poseidon_bls.rs deleted file mode 100644 index 10c5ec9..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/poseidon_bls.rs +++ /dev/null @@ -1,36 +0,0 @@ -use itertools::Itertools; -use ark_bls12_381::Fr as BlsFr; - -use super::SimdBackend; -use crate::core::backend::{Col, Column, ColumnOps}; -use crate::core::fields::m31::BaseField; -#[cfg(not(target_arch = "wasm32"))] -use crate::core::vcs::ops::MerkleHasher; -use crate::core::vcs::ops::MerkleOps; -use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleHasher; - -impl ColumnOps for SimdBackend { - type Column = Vec; - - fn bit_reverse_column(_column: &mut Self::Column) { - unimplemented!() - } -} - -impl MerkleOps for SimdBackend { - // TODO(ShaharS): replace with SIMD implementation. - fn commit_on_layer( - log_size: u32, - prev_layer: Option<&Vec>, - columns: &[&Col], - ) -> Vec { - (0..(1 << log_size)) - .map(|i| { - PoseidonBLSMerkleHasher::hash_node( - prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])), - &columns.iter().map(|column| column.at(i)).collect_vec(), - ) - }) - .collect() - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/prefix_sum.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/prefix_sum.rs deleted file mode 100644 index 652b484..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/prefix_sum.rs +++ /dev/null @@ -1,188 +0,0 @@ -use std::iter::zip; -use std::ops::{AddAssign, Sub}; - -use itertools::{izip, Itertools}; -use num_traits::Zero; - -use crate::core::backend::simd::m31::{PackedBaseField, N_LANES}; -use crate::core::backend::simd::SimdBackend; -use crate::core::backend::{Col, Column}; -use crate::core::fields::m31::BaseField; -use crate::core::utils::{ - bit_reverse, circle_domain_order_to_coset_order, coset_order_to_circle_domain_order, -}; - -/// Performs a inclusive prefix sum on values in `Coset` order when provided -/// with evaluations in bit-reversed `CircleDomain` order. -/// -/// Based on parallel Blelloch prefix sum: -/// -pub fn inclusive_prefix_sum( - bit_rev_circle_domain_evals: Col, -) -> Col { - if bit_rev_circle_domain_evals.len() < N_LANES * 4 { - return inclusive_prefix_sum_slow(bit_rev_circle_domain_evals); - } - - let mut res = bit_rev_circle_domain_evals; - let packed_len = res.data.len(); - let (l_half, r_half) = res.data.split_at_mut(packed_len / 2); - - // Up Sweep - // ======== - // Handle the first two up sweep rounds manually. - // Required due different ordering of `CircleDomain` and `Coset`. - // Evaluations are provided in bit-reversed `CircleDomain` order. - for ([l0, l1], [r0, r1]) in izip!(l_half.array_chunks_mut(), r_half.array_chunks_mut().rev()) { - let (mut half_coset0_lo, half_coset1_hi_rev) = l0.deinterleave(*l1); - let half_coset1_hi = half_coset1_hi_rev.reverse(); - let (mut half_coset0_hi, half_coset1_lo_rev) = r0.deinterleave(*r1); - let half_coset1_lo = half_coset1_lo_rev.reverse(); - up_sweep_val(&mut half_coset0_lo, half_coset1_lo); - up_sweep_val(&mut half_coset0_hi, half_coset1_hi); - *l0 = half_coset0_lo; - *l1 = half_coset0_hi; - *r0 = half_coset1_lo; - *r1 = half_coset1_hi; - } - let half_coset0_sums: &mut [PackedBaseField] = l_half; - for i in 0..half_coset0_sums.len() / 2 { - let lo_index = i * 2; - let hi_index = half_coset0_sums.len() - 1 - i * 2; - let hi = half_coset0_sums[hi_index]; - up_sweep_val(&mut half_coset0_sums[lo_index], hi) - } - // Handle remaining up sweep rounds. - let mut chunk_size = half_coset0_sums.len() / 2; - while chunk_size > 1 { - let (lows, highs) = half_coset0_sums.split_at_mut(chunk_size); - zip(lows.array_chunks_mut(), highs.array_chunks()) - .for_each(|([lo, _], [hi, _])| up_sweep_val(lo, *hi)); - chunk_size /= 2; - } - // Up sweep the last SIMD vector. - let mut first_vec = half_coset0_sums.first().unwrap().to_array(); - let mut chunk_size = first_vec.len() / 2; - while chunk_size > 0 { - let (lows, highs) = first_vec.split_at_mut(chunk_size); - zip(lows, highs).for_each(|(lo, hi)| up_sweep_val(lo, *hi)); - chunk_size /= 2; - } - - // Down Sweep - // ========== - // Down sweep the last SIMD vector. - let mut chunk_size = 1; - while chunk_size < first_vec.len() { - let (lows, highs) = first_vec.split_at_mut(chunk_size); - zip(lows, highs).for_each(|(lo, hi)| down_sweep_val(lo, hi)); - chunk_size *= 2; - } - // Re-insert the SIMD vector. - *half_coset0_sums.first_mut().unwrap() = first_vec.into(); - // Handle remaining down sweep rounds (except first two). - let mut chunk_size = 2; - while chunk_size < half_coset0_sums.len() { - let (lows, highs) = half_coset0_sums.split_at_mut(chunk_size); - zip(lows.array_chunks_mut(), highs.array_chunks_mut()) - .for_each(|([lo, _], [hi, _])| down_sweep_val(lo, hi)); - chunk_size *= 2; - } - // Handle last two down sweep rounds manually. - // Required due different ordering of `CircleDomain` and `Coset`. - // Evaluations must be returned in bit-reversed `CircleDomain` order. - for i in 0..half_coset0_sums.len() / 2 { - let lo_index = i * 2; - let hi_index = half_coset0_sums.len() - 1 - i * 2; - let (mut lo, mut hi) = (half_coset0_sums[lo_index], half_coset0_sums[hi_index]); - down_sweep_val(&mut lo, &mut hi); - (half_coset0_sums[lo_index], half_coset0_sums[hi_index]) = (lo, hi); - } - for ([l0, l1], [r0, r1]) in izip!(l_half.array_chunks_mut(), r_half.array_chunks_mut().rev()) { - let mut half_coset0_lo = *l0; - let mut half_coset1_lo = *r0; - down_sweep_val(&mut half_coset0_lo, &mut half_coset1_lo); - let mut half_coset0_hi = *l1; - let mut half_coset1_hi = *r1; - down_sweep_val(&mut half_coset0_hi, &mut half_coset1_hi); - (*l0, *l1) = half_coset0_lo.interleave(half_coset1_hi.reverse()); - (*r0, *r1) = half_coset0_hi.interleave(half_coset1_lo.reverse()); - } - - res -} - -fn up_sweep_val(lo: &mut F, hi: F) { - *lo += hi; -} - -fn down_sweep_val + Copy>(lo: &mut F, hi: &mut F) { - (*lo, *hi) = (*lo - *hi, *lo) -} - -fn inclusive_prefix_sum_slow( - bit_rev_circle_domain_evals: Col, -) -> Col { - // Obtain values in coset order. - let mut coset_order_eval = bit_rev_circle_domain_evals.into_cpu_vec(); - bit_reverse(&mut coset_order_eval); - coset_order_eval = circle_domain_order_to_coset_order(&coset_order_eval); - let coset_order_prefix_sum = coset_order_eval - .into_iter() - .scan(BaseField::zero(), |acc, v| { - *acc += v; - Some(*acc) - }) - .collect_vec(); - let mut circle_domain_order_eval = coset_order_to_circle_domain_order(&coset_order_prefix_sum); - bit_reverse(&mut circle_domain_order_eval); - circle_domain_order_eval.into_iter().collect() -} - -#[cfg(test)] -mod tests { - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - use test_log::test; - - use super::inclusive_prefix_sum; - use crate::core::backend::simd::column::BaseColumn; - use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum_slow; - use crate::core::backend::Column; - - #[test] - fn exclusive_prefix_sum_simd_with_log_size_3_works() { - const LOG_N: u32 = 3; - let mut rng = SmallRng::seed_from_u64(0); - let evals: BaseColumn = (0..1 << LOG_N).map(|_| rng.gen()).collect(); - let expected = inclusive_prefix_sum_slow(evals.clone()); - - let res = inclusive_prefix_sum(evals); - - assert_eq!(res.to_cpu(), expected.to_cpu()); - } - - #[test] - fn exclusive_prefix_sum_simd_with_log_size_6_works() { - const LOG_N: u32 = 6; - let mut rng = SmallRng::seed_from_u64(0); - let evals: BaseColumn = (0..1 << LOG_N).map(|_| rng.gen()).collect(); - let expected = inclusive_prefix_sum_slow(evals.clone()); - - let res = inclusive_prefix_sum(evals); - - assert_eq!(res.to_cpu(), expected.to_cpu()); - } - - #[test] - fn exclusive_prefix_sum_simd_with_log_size_8_works() { - const LOG_N: u32 = 8; - let mut rng = SmallRng::seed_from_u64(0); - let evals: BaseColumn = (0..1 << LOG_N).map(|_| rng.gen()).collect(); - let expected = inclusive_prefix_sum_slow(evals.clone()); - - let res = inclusive_prefix_sum(evals); - - assert_eq!(res.to_cpu(), expected.to_cpu()); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/qm31.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/qm31.rs deleted file mode 100644 index 13d03ce..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/qm31.rs +++ /dev/null @@ -1,357 +0,0 @@ -use std::array; -use std::iter::Sum; -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; - -use bytemuck::{Pod, Zeroable}; -use num_traits::{One, Zero}; -use rand::distributions::{Distribution, Standard}; - -use super::cm31::PackedCM31; -use super::m31::{PackedM31, N_LANES}; -use crate::core::fields::qm31::QM31; -use crate::core::fields::FieldExpOps; - -pub type PackedSecureField = PackedQM31; - -/// SIMD implementation of [`QM31`]. -#[derive(Copy, Clone, Debug)] -pub struct PackedQM31(pub [PackedCM31; 2]); - -impl PackedQM31 { - /// Constructs a new instance with all vector elements set to `value`. - pub fn broadcast(value: QM31) -> Self { - Self([ - PackedCM31::broadcast(value.0), - PackedCM31::broadcast(value.1), - ]) - } - - /// Returns all `a` values such that each vector element is represented as `a + bu`. - pub fn a(&self) -> PackedCM31 { - self.0[0] - } - - /// Returns all `b` values such that each vector element is represented as `a + bu`. - pub fn b(&self) -> PackedCM31 { - self.0[1] - } - - pub fn to_array(&self) -> [QM31; N_LANES] { - let a = self.a().to_array(); - let b = self.b().to_array(); - array::from_fn(|i| QM31(a[i], b[i])) - } - - pub fn from_array(values: [QM31; N_LANES]) -> Self { - let a = values.map(|v| v.0); - let b = values.map(|v| v.1); - Self([PackedCM31::from_array(a), PackedCM31::from_array(b)]) - } - - /// Interleaves two vectors. - pub fn interleave(self, other: Self) -> (Self, Self) { - let Self([a_evens, b_evens]) = self; - let Self([a_odds, b_odds]) = other; - let (a_lhs, a_rhs) = a_evens.interleave(a_odds); - let (b_lhs, b_rhs) = b_evens.interleave(b_odds); - (Self([a_lhs, b_lhs]), Self([a_rhs, b_rhs])) - } - - /// Deinterleaves two vectors. - pub fn deinterleave(self, other: Self) -> (Self, Self) { - let Self([a_lhs, b_lhs]) = self; - let Self([a_rhs, b_rhs]) = other; - let (a_evens, a_odds) = a_lhs.deinterleave(a_rhs); - let (b_evens, b_odds) = b_lhs.deinterleave(b_rhs); - (Self([a_evens, b_evens]), Self([a_odds, b_odds])) - } - - /// Sums all the elements in the vector. - pub fn pointwise_sum(self) -> QM31 { - self.to_array().into_iter().sum() - } - - /// Doubles each element in the vector. - pub fn double(self) -> Self { - let Self([a, b]) = self; - Self([a.double(), b.double()]) - } - - /// Returns vectors `a, b, c, d` such that element `i` is represented as - /// `QM31(a_i, b_i, c_i, d_i)`. - pub fn into_packed_m31s(self) -> [PackedM31; 4] { - let Self([PackedCM31([a, b]), PackedCM31([c, d])]) = self; - [a, b, c, d] - } - - /// Creates an instance from vectors `a, b, c, d` such that element `i` - /// is represented as `QM31(a_i, b_i, c_i, d_i)`. - pub fn from_packed_m31s([a, b, c, d]: [PackedM31; 4]) -> Self { - Self([PackedCM31([a, b]), PackedCM31([c, d])]) - } -} - -impl Add for PackedQM31 { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - Self([self.a() + rhs.a(), self.b() + rhs.b()]) - } -} - -impl Sub for PackedQM31 { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - Self([self.a() - rhs.a(), self.b() - rhs.b()]) - } -} - -impl Mul for PackedQM31 { - type Output = Self; - - fn mul(self, rhs: Self) -> Self::Output { - // Compute using Karatsuba. - // (a + ub) * (c + ud) = - // (ac + (2+i)bd) + (ad + bc)u = - // ac + 2bd + ibd + (ad + bc)u. - let ac = self.a() * rhs.a(); - let bd = self.b() * rhs.b(); - let bd_times_1_plus_i = PackedCM31([bd.a() - bd.b(), bd.a() + bd.b()]); - // Computes ac + bd. - let ac_p_bd = ac + bd; - // Computes ad + bc. - let ad_p_bc = (self.a() + self.b()) * (rhs.a() + rhs.b()) - ac_p_bd; - // ac + 2bd + ibd = - // ac + bd + bd + ibd - let l = PackedCM31([ - ac_p_bd.a() + bd_times_1_plus_i.a(), - ac_p_bd.b() + bd_times_1_plus_i.b(), - ]); - Self([l, ad_p_bc]) - } -} - -impl Zero for PackedQM31 { - fn zero() -> Self { - Self([PackedCM31::zero(), PackedCM31::zero()]) - } - - fn is_zero(&self) -> bool { - self.a().is_zero() && self.b().is_zero() - } -} - -impl One for PackedQM31 { - fn one() -> Self { - Self([PackedCM31::one(), PackedCM31::zero()]) - } -} - -impl AddAssign for PackedQM31 { - fn add_assign(&mut self, rhs: Self) { - *self = *self + rhs; - } -} - -impl MulAssign for PackedQM31 { - fn mul_assign(&mut self, rhs: Self) { - *self = *self * rhs; - } -} - -impl FieldExpOps for PackedQM31 { - fn inverse(&self) -> Self { - assert!(!self.is_zero(), "0 has no inverse"); - // (a + bu)^-1 = (a - bu) / (a^2 - (2+i)b^2). - let b2 = self.b().square(); - let ib2 = PackedCM31([-b2.b(), b2.a()]); - let denom = self.a().square() - (b2 + b2 + ib2); - let denom_inverse = denom.inverse(); - Self([self.a() * denom_inverse, -self.b() * denom_inverse]) - } -} - -impl Add for PackedQM31 { - type Output = Self; - - fn add(self, rhs: PackedM31) -> Self::Output { - Self([self.a() + rhs, self.b()]) - } -} - -impl Mul for PackedQM31 { - type Output = Self; - - fn mul(self, rhs: PackedM31) -> Self::Output { - let Self([a, b]) = self; - Self([a * rhs, b * rhs]) - } -} - -impl Mul for PackedQM31 { - type Output = Self; - - fn mul(self, rhs: PackedCM31) -> Self::Output { - let Self([a, b]) = self; - Self([a * rhs, b * rhs]) - } -} - -impl Sub for PackedQM31 { - type Output = Self; - - fn sub(self, rhs: PackedM31) -> Self::Output { - let Self([a, b]) = self; - Self([a - rhs, b]) - } -} - -impl Add for PackedQM31 { - type Output = Self; - - fn add(self, rhs: QM31) -> Self::Output { - self + PackedQM31::broadcast(rhs) - } -} - -impl Sub for PackedQM31 { - type Output = Self; - - fn sub(self, rhs: QM31) -> Self::Output { - self - PackedQM31::broadcast(rhs) - } -} - -impl Mul for PackedQM31 { - type Output = Self; - - fn mul(self, rhs: QM31) -> Self::Output { - self * PackedQM31::broadcast(rhs) - } -} - -impl SubAssign for PackedQM31 { - fn sub_assign(&mut self, rhs: Self) { - *self = *self - rhs; - } -} - -unsafe impl Pod for PackedQM31 {} - -unsafe impl Zeroable for PackedQM31 { - fn zeroed() -> Self { - unsafe { core::mem::zeroed() } - } -} - -impl Sum for PackedQM31 { - fn sum(mut iter: I) -> Self - where - I: Iterator, - { - let first = iter.next().unwrap_or_else(Self::zero); - iter.fold(first, |a, b| a + b) - } -} - -impl<'a> Sum<&'a Self> for PackedQM31 { - fn sum(iter: I) -> Self - where - I: Iterator, - { - iter.copied().sum() - } -} - -impl Neg for PackedQM31 { - type Output = Self; - - fn neg(self) -> Self::Output { - let Self([a, b]) = self; - Self([-a, -b]) - } -} - -impl Distribution for Standard { - fn sample(&self, rng: &mut R) -> PackedQM31 { - PackedQM31::from_array(rng.gen()) - } -} - -impl From for PackedQM31 { - fn from(value: PackedM31) -> Self { - PackedQM31::from_packed_m31s([ - value, - PackedM31::zero(), - PackedM31::zero(), - PackedM31::zero(), - ]) - } -} - -impl From for PackedQM31 { - fn from(value: QM31) -> Self { - PackedQM31::broadcast(value) - } -} - -#[cfg(test)] -mod tests { - use std::array; - - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use crate::core::backend::simd::qm31::PackedQM31; - - #[test] - fn addition_works() { - let mut rng = SmallRng::seed_from_u64(0); - let lhs = rng.gen(); - let rhs = rng.gen(); - let packed_lhs = PackedQM31::from_array(lhs); - let packed_rhs = PackedQM31::from_array(rhs); - - let res = packed_lhs + packed_rhs; - - assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i])); - } - - #[test] - fn subtraction_works() { - let mut rng = SmallRng::seed_from_u64(0); - let lhs = rng.gen(); - let rhs = rng.gen(); - let packed_lhs = PackedQM31::from_array(lhs); - let packed_rhs = PackedQM31::from_array(rhs); - - let res = packed_lhs - packed_rhs; - - assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i])); - } - - #[test] - fn multiplication_works() { - let mut rng = SmallRng::seed_from_u64(0); - let lhs = rng.gen(); - let rhs = rng.gen(); - let packed_lhs = PackedQM31::from_array(lhs); - let packed_rhs = PackedQM31::from_array(rhs); - - let res = packed_lhs * packed_rhs; - - assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i])); - } - - #[test] - fn negation_works() { - let mut rng = SmallRng::seed_from_u64(0); - let values = rng.gen(); - let packed_values = PackedQM31::from_array(values); - - let res = -packed_values; - - assert_eq!(res.to_array(), values.map(|v| -v)); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/quotients.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/quotients.rs deleted file mode 100644 index 3cb664a..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/quotients.rs +++ /dev/null @@ -1,314 +0,0 @@ -use itertools::{izip, zip_eq, Itertools}; -use num_traits::Zero; -use tracing::{span, Level}; - -use super::cm31::PackedCM31; -use super::column::CM31Column; -use super::domain::CircleDomainBitRevIterator; -use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; -use super::qm31::PackedSecureField; -use super::SimdBackend; -use crate::core::backend::cpu::quotients::{batch_random_coeffs, column_line_coeffs}; -use crate::core::backend::Column; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE}; -use crate::core::fields::FieldExpOps; -use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps}; -use crate::core::poly::circle::{CircleDomain, CircleEvaluation, PolyOps, SecureEvaluation}; -use crate::core::poly::BitReversedOrder; -use crate::core::utils::bit_reverse; - -pub struct QuotientConstants { - pub line_coeffs: Vec>, - pub batch_random_coeffs: Vec, - pub denominator_inverses: Vec, -} - -impl QuotientOps for SimdBackend { - fn accumulate_quotients( - domain: CircleDomain, - columns: &[&CircleEvaluation], - random_coeff: SecureField, - sample_batches: &[ColumnSampleBatch], - log_blowup_factor: u32, - ) -> SecureEvaluation { - // Split the domain into a subdomain and a shift coset. - // TODO(spapini): Move to the caller when Columns support slices. - let (subdomain, mut subdomain_shifts) = domain.split(log_blowup_factor); - - // Bit reverse the shifts. - // Since we traverse the domain in bit-reversed order, we need bit-reverse the shifts. - // To see why, consider the index of a point in the natural order of the domain - // (least to most): - // b0 b1 b2 b3 b4 b5 - // b0 adds G, b1 adds 2G, etc.. (b5 is special and flips the sign of the point). - // Splitting the domain to 4 parts yields: - // subdomain: b2 b3 b4 b5, shifts: b0 b1. - // b2 b3 b4 b5 is indeed a circle domain, with a bigger jump. - // Traversing the domain in bit-reversed order, after we finish with b5, b4, b3, b2, - // we need to change b1 and then b0. This is the bit reverse of the shift b0 b1. - bit_reverse(&mut subdomain_shifts); - - let (span, mut extended_eval, subeval_polys) = accumulate_quotients_on_subdomain( - subdomain, - sample_batches, - random_coeff, - columns, - domain, - ); - - // Extend the evaluation to the full domain. - // TODO(spapini): Try to optimize out all these copies. - for (ci, &c) in subdomain_shifts.iter().enumerate() { - let subdomain = subdomain.shift(c); - - let twiddles = SimdBackend::precompute_twiddles(subdomain.half_coset); - #[allow(clippy::needless_range_loop)] - for i in 0..SECURE_EXTENSION_DEGREE { - // Sanity check. - let eval = subeval_polys[i].evaluate_with_twiddles(subdomain, &twiddles); - extended_eval.columns[i].data[(ci * eval.data.len())..((ci + 1) * eval.data.len())] - .copy_from_slice(&eval.data); - } - } - span.exit(); - - SecureEvaluation::new(domain, extended_eval) - } -} - -fn accumulate_quotients_on_subdomain( - subdomain: CircleDomain, - sample_batches: &[ColumnSampleBatch], - random_coeff: SecureField, - columns: &[&CircleEvaluation], - domain: CircleDomain, -) -> ( - span::EnteredSpan, - SecureColumnByCoords, - [crate::core::poly::circle::CirclePoly; 4], -) { - assert!(subdomain.log_size() >= LOG_N_LANES + 2); - let mut values = - unsafe { SecureColumnByCoords::::uninitialized(subdomain.size()) }; - let quotient_constants = quotient_constants(sample_batches, random_coeff, subdomain); - - let span = span!(Level::INFO, "Quotient accumulation").entered(); - for (quad_row, points) in CircleDomainBitRevIterator::new(subdomain) - .array_chunks::<4>() - .enumerate() - { - // TODO(spapini): Use optimized domain iteration. - let (y01, _) = points[0].y.deinterleave(points[1].y); - let (y23, _) = points[2].y.deinterleave(points[3].y); - let (spaced_ys, _) = y01.deinterleave(y23); - let row_accumulator = accumulate_row_quotients( - sample_batches, - columns, - "ient_constants, - quad_row, - spaced_ys, - ); - #[allow(clippy::needless_range_loop)] - for i in 0..4 { - unsafe { values.set_packed((quad_row << 2) + i, row_accumulator[i]) }; - } - } - span.exit(); - let span = span!(Level::INFO, "Quotient extension").entered(); - - // Extend the evaluation to the full domain. - let extended_eval = - unsafe { SecureColumnByCoords::::uninitialized(domain.size()) }; - - let mut i = 0; - let values = values.columns; - let twiddles = SimdBackend::precompute_twiddles(subdomain.half_coset); - let subeval_polys = values.map(|c| { - i += 1; - CircleEvaluation::::new(subdomain, c) - .interpolate_with_twiddles(&twiddles) - }); - (span, extended_eval, subeval_polys) -} - -/// Accumulates the quotients for 4 * N_LANES rows at a time. -/// spaced_ys - y values for N_LANES points in the domain, in jumps of 4. -pub fn accumulate_row_quotients( - sample_batches: &[ColumnSampleBatch], - columns: &[&CircleEvaluation], - quotient_constants: &QuotientConstants, - quad_row: usize, - spaced_ys: PackedBaseField, -) -> [PackedSecureField; 4] { - let mut row_accumulator = [PackedSecureField::zero(); 4]; - for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!( - sample_batches, - "ient_constants.line_coeffs, - "ient_constants.batch_random_coeffs, - "ient_constants.denominator_inverses - ) { - let mut numerator = [PackedSecureField::zero(); 4]; - for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs) - { - let column = &columns[*column_index]; - let cvalues: [_; 4] = std::array::from_fn(|i| { - PackedSecureField::broadcast(*c) * column.data[(quad_row << 2) + i] - }); - - // The numerator is the line equation: - // c * value - a * point.y - b; - // Note that a, b, c were already multilpied by random_coeff^i. - // See [column_line_coeffs()] for more details. - // This is why we only add here. - // 4 consecutive point in the domain in bit reversed order are: - // P, -P, P + H, -P + H. - // H being the half point (-1,0). The y values for these are - // P.y, -P.y, -P.y, P.y. - // We use this fact to save multiplications. - // spaced_ys are the y value in jumps of 4: - // P0.y, P1.y, P2.y, ... - let spaced_ay = PackedSecureField::broadcast(*a) * spaced_ys; - // t0:t1 = a*P0.y, -a*P0.y, a*P1.y, -a*P1.y, ... - let (t0, t1) = spaced_ay.interleave(-spaced_ay); - // t2:t3:t4:t5 = a*P0.y, -a*P0.y, -a*P0.y, a*P0.y, a*P1.y, -a*P1.y, ... - let (t2, t3) = t0.interleave(-t0); - let (t4, t5) = t1.interleave(-t1); - let ay = [t2, t3, t4, t5]; - for i in 0..4 { - numerator[i] += cvalues[i] - ay[i] - PackedSecureField::broadcast(*b); - } - } - - for i in 0..4 { - row_accumulator[i] = row_accumulator[i] * PackedSecureField::broadcast(*batch_coeff) - + numerator[i] * denominator_inverses.data[(quad_row << 2) + i]; - } - } - row_accumulator -} - -fn denominator_inverses( - sample_batches: &[ColumnSampleBatch], - domain: CircleDomain, -) -> Vec { - // We want a P to be on a line that passes through a point Pr + uPi in QM31^2, and its conjugate - // Pr - uPi. Thus, Pr - P is parallel to Pi. Or, (Pr - P).x * Pi.y - (Pr - P).y * Pi.x = 0. - let flat_denominators: CM31Column = sample_batches - .iter() - .flat_map(|sample_batch| { - // Extract Pr, Pi. - let prx = PackedCM31::broadcast(sample_batch.point.x.0); - let pry = PackedCM31::broadcast(sample_batch.point.y.0); - let pix = PackedCM31::broadcast(sample_batch.point.x.1); - let piy = PackedCM31::broadcast(sample_batch.point.y.1); - - // Line equation through pr +-u pi. - // (p-pr)* - CircleDomainBitRevIterator::new(domain) - .map(|points| (prx - points.x) * piy - (pry - points.y) * pix) - .collect_vec() - }) - .collect(); - - let mut flat_denominator_inverses = - unsafe { CM31Column::uninitialized(flat_denominators.len()) }; - FieldExpOps::batch_inverse( - &flat_denominators.data, - &mut flat_denominator_inverses.data[..], - ); - - flat_denominator_inverses - .data - .chunks(domain.size() / N_LANES) - .map(|denominator_inverses| denominator_inverses.iter().copied().collect()) - .collect() -} - -fn quotient_constants( - sample_batches: &[ColumnSampleBatch], - random_coeff: SecureField, - domain: CircleDomain, -) -> QuotientConstants { - let _span = span!(Level::INFO, "Quotient constants").entered(); - let line_coeffs = column_line_coeffs(sample_batches, random_coeff); - let batch_random_coeffs = batch_random_coeffs(sample_batches, random_coeff); - let denominator_inverses = denominator_inverses(sample_batches, domain); - QuotientConstants { - line_coeffs, - batch_random_coeffs, - denominator_inverses, - } -} - -#[cfg(test)] -mod tests { - use itertools::Itertools; - - use crate::core::backend::simd::column::BaseColumn; - use crate::core::backend::simd::SimdBackend; - use crate::core::backend::{Column, CpuBackend}; - use crate::core::circle::SECURE_FIELD_CIRCLE_GEN; - use crate::core::fields::m31::BaseField; - use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps}; - use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; - use crate::core::poly::BitReversedOrder; - use crate::qm31; - - #[test] - fn test_accumulate_quotients() { - const LOG_SIZE: u32 = 8; - const LOG_BLOWUP_FACTOR: u32 = 1; - let small_domain = CanonicCoset::new(LOG_SIZE).circle_domain(); - let domain = CanonicCoset::new(LOG_SIZE + LOG_BLOWUP_FACTOR).circle_domain(); - let e0: BaseColumn = (0..small_domain.size()).map(BaseField::from).collect(); - let e1: BaseColumn = (0..small_domain.size()) - .map(|i| BaseField::from(2 * i)) - .collect(); - let polys = vec![ - CircleEvaluation::::new(small_domain, e0) - .interpolate(), - CircleEvaluation::::new(small_domain, e1) - .interpolate(), - ]; - let columns = vec![polys[0].evaluate(domain), polys[1].evaluate(domain)]; - let random_coeff = qm31!(1, 2, 3, 4); - let a = polys[0].eval_at_point(SECURE_FIELD_CIRCLE_GEN); - let b = polys[1].eval_at_point(SECURE_FIELD_CIRCLE_GEN); - let samples = vec![ColumnSampleBatch { - point: SECURE_FIELD_CIRCLE_GEN, - columns_and_values: vec![(0, a), (1, b)], - }]; - let cpu_columns = columns - .iter() - .map(|c| { - CircleEvaluation::::new( - c.domain, - c.values.to_cpu(), - ) - }) - .collect::>(); - let cpu_result = CpuBackend::accumulate_quotients( - domain, - &cpu_columns.iter().collect_vec(), - random_coeff, - &samples, - LOG_BLOWUP_FACTOR, - ) - .values - .to_vec(); - - let res = SimdBackend::accumulate_quotients( - domain, - &columns.iter().collect_vec(), - random_coeff, - &samples, - LOG_BLOWUP_FACTOR, - ) - .values - .to_vec(); - - assert_eq!(res, cpu_result); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/utils.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/utils.rs deleted file mode 100644 index 87dfd22..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/utils.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::simd::Swizzle; - -/// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. -pub struct InterleaveEvens; - -impl Swizzle for InterleaveEvens { - const INDEX: [usize; N] = parity_interleave(false); -} - -/// Used with [`Swizzle::concat_swizzle`] to interleave the odd values of two vectors. -pub struct InterleaveOdds; - -impl Swizzle for InterleaveOdds { - const INDEX: [usize; N] = parity_interleave(true); -} - -const fn parity_interleave(odd: bool) -> [usize; N] { - let mut res = [0; N]; - let mut i = 0; - while i < N { - res[i] = (i % 2) * N + (i / 2) * 2 + if odd { 1 } else { 0 }; - i += 1; - } - res -} - -#[cfg(test)] -mod tests { - use std::simd::{u32x4, Swizzle}; - - use super::{InterleaveEvens, InterleaveOdds}; - - #[test] - fn interleave_evens() { - let lo = u32x4::from_array([0, 1, 2, 3]); - let hi = u32x4::from_array([4, 5, 6, 7]); - - let res = InterleaveEvens::concat_swizzle(lo, hi); - - assert_eq!(res, u32x4::from_array([0, 4, 2, 6])); - } - - #[test] - fn interleave_odds() { - let lo = u32x4::from_array([0, 1, 2, 3]); - let hi = u32x4::from_array([4, 5, 6, 7]); - - let res = InterleaveOdds::concat_swizzle(lo, hi); - - assert_eq!(res, u32x4::from_array([1, 5, 3, 7])); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/very_packed_m31.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/very_packed_m31.rs deleted file mode 100644 index 2e344b8..0000000 --- a/Stwo_wrapper/crates/prover/src/core/backend/simd/very_packed_m31.rs +++ /dev/null @@ -1,222 +0,0 @@ -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; - -use bytemuck::{Pod, Zeroable}; -use num_traits::{One, Zero}; - -use super::cm31::PackedCM31; -use super::m31::{PackedM31, N_LANES}; -use super::qm31::PackedQM31; -use crate::core::fields::cm31::CM31; -use crate::core::fields::m31::{pow2147483645, M31}; -use crate::core::fields::qm31::QM31; -use crate::core::fields::FieldExpOps; - -pub const LOG_N_VERY_PACKED_ELEMS: u32 = 1; -pub const N_VERY_PACKED_ELEMS: usize = 1 << LOG_N_VERY_PACKED_ELEMS; - -#[derive(Copy, Clone, Debug)] -#[repr(transparent)] -pub struct Vectorized(pub [A; N]); - -impl Vectorized { - pub fn from_fn(cb: F) -> Self - where - F: FnMut(usize) -> A, - { - Vectorized(std::array::from_fn(cb)) - } -} - -unsafe impl Zeroable for Vectorized { - fn zeroed() -> Self { - unsafe { core::mem::zeroed() } - } -} -unsafe impl Pod for Vectorized {} - -pub type VeryPackedM31 = Vectorized; -pub type VeryPackedCM31 = Vectorized; -pub type VeryPackedQM31 = Vectorized; -pub type VeryPackedBaseField = VeryPackedM31; -pub type VeryPackedSecureField = VeryPackedQM31; - -impl VeryPackedM31 { - pub fn broadcast(value: M31) -> Self { - Self::from_fn(|_| PackedM31::broadcast(value)) - } - - pub fn from_array(values: [M31; N_LANES * N_VERY_PACKED_ELEMS]) -> VeryPackedM31 { - Self::from_fn(|i| { - let start = i * N_LANES; - let end = start + N_LANES; - PackedM31::from_array(values[start..end].try_into().unwrap()) - }) - } - - pub fn to_array(&self) -> [M31; N_LANES * N_VERY_PACKED_ELEMS] { - // Safety: We are transmuting &[A; N_VERY_PACKED_ELEMS] into &[i32; N_LANES * - // N_VERY_PACKED_ELEMS] because we know that A contains [i32; N_LANES] and the - // memory layout is contiguous. - unsafe { - std::slice::from_raw_parts(self.0.as_ptr() as *const M31, N_LANES * N_VERY_PACKED_ELEMS) - .try_into() - .unwrap() - } - } -} - -impl VeryPackedCM31 { - pub fn broadcast(value: CM31) -> Self { - Self::from_fn(|_| PackedCM31::broadcast(value)) - } -} - -impl VeryPackedQM31 { - pub fn broadcast(value: QM31) -> Self { - Self::from_fn(|_| PackedQM31::broadcast(value)) - } - - pub fn from_very_packed_m31s([a, b, c, d]: [VeryPackedM31; 4]) -> Self { - Self::from_fn(|i| PackedQM31::from_packed_m31s([a.0[i], b.0[i], c.0[i], d.0[i]])) - } -} -impl From for VeryPackedM31 { - fn from(v: M31) -> Self { - Self::broadcast(v) - } -} - -impl From for VeryPackedQM31 { - fn from(value: VeryPackedM31) -> Self { - VeryPackedQM31::from_very_packed_m31s([ - value, - VeryPackedM31::zero(), - VeryPackedM31::zero(), - VeryPackedM31::zero(), - ]) - } -} - -impl From for VeryPackedQM31 { - fn from(value: QM31) -> Self { - VeryPackedQM31::broadcast(value) - } -} - -trait Scalar {} -impl Scalar for M31 {} -impl Scalar for CM31 {} -impl Scalar for QM31 {} -impl Scalar for PackedM31 {} -impl Scalar for PackedCM31 {} -impl Scalar for PackedQM31 {} - -impl + Copy, B: Copy, const N: usize> Add> for Vectorized { - type Output = Vectorized; - - fn add(self, other: Vectorized) -> Self::Output { - Vectorized::from_fn(|i| self.0[i] + other.0[i]) - } -} - -impl + Copy, B: Scalar + Copy, const N: usize> Add for Vectorized { - type Output = Vectorized; - - fn add(self, other: B) -> Self::Output { - Vectorized::from_fn(|i| self.0[i] + other) - } -} - -impl + Copy, B: Copy, const N: usize> Sub> for Vectorized { - type Output = Vectorized; - - fn sub(self, other: Vectorized) -> Self::Output { - Vectorized::from_fn(|i| self.0[i] - other.0[i]) - } -} - -impl + Copy, B: Scalar + Copy, const N: usize> Sub for Vectorized { - type Output = Vectorized; - - fn sub(self, other: B) -> Self::Output { - Vectorized::from_fn(|i| self.0[i] - other) - } -} - -impl + Copy, B: Copy, const N: usize> Mul> for Vectorized { - type Output = Vectorized; - - fn mul(self, other: Vectorized) -> Self::Output { - Vectorized::from_fn(|i| self.0[i] * other.0[i]) - } -} - -impl + Copy, B: Scalar + Copy, const N: usize> Mul for Vectorized { - type Output = Vectorized; - - fn mul(self, other: B) -> Self::Output { - Vectorized::from_fn(|i| self.0[i] * other) - } -} - -impl + Copy, B: Copy, const N: usize> AddAssign> - for Vectorized -{ - fn add_assign(&mut self, other: Vectorized) { - for i in 0..N { - self.0[i] += other.0[i]; - } - } -} - -impl + Copy, B: Scalar + Copy, const N: usize> AddAssign for Vectorized { - fn add_assign(&mut self, other: B) { - for i in 0..N { - self.0[i] += other; - } - } -} - -impl + Copy, B: Copy, const N: usize> MulAssign> - for Vectorized -{ - fn mul_assign(&mut self, other: Vectorized) { - for i in 0..N { - self.0[i] *= other.0[i]; - } - } -} - -impl Neg for Vectorized { - type Output = Vectorized; - - #[inline(always)] - fn neg(self) -> Self::Output { - Vectorized::from_fn(|i| self.0[i].neg()) - } -} - -impl Zero for Vectorized { - fn zero() -> Self { - Vectorized::from_fn(|_| A::zero()) - } - - fn is_zero(&self) -> bool { - self.0.iter().all(A::is_zero) - } -} - -impl One for Vectorized { - fn one() -> Self { - Vectorized::from_fn(|_| A::one()) - } -} - -impl FieldExpOps for Vectorized { - fn inverse(&self) -> Self { - Vectorized::from_fn(|i| { - assert!(!self.0[i].is_zero(), "0 has no inverse"); - pow2147483645(self.0[i]) - }) - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/channel/blake2s.rs b/Stwo_wrapper/crates/prover/src/core/channel/blake2s.rs deleted file mode 100644 index 9861862..0000000 --- a/Stwo_wrapper/crates/prover/src/core/channel/blake2s.rs +++ /dev/null @@ -1,186 +0,0 @@ -use std::iter; - -use super::{Channel, ChannelTime}; -use crate::core::fields::m31::{BaseField, N_BYTES_FELT, P}; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; -use crate::core::fields::IntoSlice; -use crate::core::vcs::blake2_hash::{Blake2sHash, Blake2sHasher}; -use crate::core::vcs::blake2s_ref::compress; - -pub const BLAKE_BYTES_PER_HASH: usize = 32; -pub const FELTS_PER_HASH: usize = 8; - -/// A channel that can be used to draw random elements from a [Blake2sHash] digest. -#[derive(Default, Clone)] -pub struct Blake2sChannel { - digest: Blake2sHash, - pub channel_time: ChannelTime, -} - -impl Blake2sChannel { - pub fn digest(&self) -> Blake2sHash { - self.digest - } - pub fn update_digest(&mut self, new_digest: Blake2sHash) { - self.digest = new_digest; - self.channel_time.inc_challenges(); - } - /// Generates a uniform random vector of BaseField elements. - fn draw_base_felts(&mut self) -> [BaseField; FELTS_PER_HASH] { - // Repeats hashing with an increasing counter until getting a good result. - // Retry probability for each round is ~ 2^(-28). - loop { - let u32s: [u32; FELTS_PER_HASH] = self - .draw_random_bytes() - .chunks_exact(N_BYTES_FELT) // 4 bytes per u32. - .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap())) - .collect::>() - .try_into() - .unwrap(); - - // Retry if not all the u32 are in the range [0, 2P). - if u32s.iter().all(|x| *x < 2 * P) { - return u32s - .into_iter() - .map(|x| BaseField::reduce(x as u64)) - .collect::>() - .try_into() - .unwrap(); - } - } - } -} - -impl Channel for Blake2sChannel { - const BYTES_PER_HASH: usize = BLAKE_BYTES_PER_HASH; - - fn trailing_zeros(&self) -> u32 { - u128::from_le_bytes(std::array::from_fn(|i| self.digest.0[i])).trailing_zeros() - } - - fn mix_felts(&mut self, felts: &[SecureField]) { - let mut hasher = Blake2sHasher::new(); - hasher.update(self.digest.as_ref()); - hasher.update(IntoSlice::::into_slice(felts)); - - self.update_digest(hasher.finalize()); - } - - fn mix_nonce(&mut self, nonce: u64) { - let digest: [u32; 8] = unsafe { std::mem::transmute(self.digest) }; - let mut msg = [0; 16]; - msg[0] = nonce as u32; - msg[1] = (nonce >> 32) as u32; - let res = compress(std::array::from_fn(|i| digest[i]), msg, 0, 0, 0, 0); - - self.update_digest(unsafe { std::mem::transmute(res) }); - } - - fn draw_felt(&mut self) -> SecureField { - let felts: [BaseField; FELTS_PER_HASH] = self.draw_base_felts(); - SecureField::from_m31_array(felts[..SECURE_EXTENSION_DEGREE].try_into().unwrap()) - } - - fn draw_felts(&mut self, n_felts: usize) -> Vec { - let mut felts = iter::from_fn(|| Some(self.draw_base_felts())).flatten(); - let secure_felts = iter::from_fn(|| { - Some(SecureField::from_m31_array([ - felts.next()?, - felts.next()?, - felts.next()?, - felts.next()?, - ])) - }); - secure_felts.take(n_felts).collect() - } - - fn draw_random_bytes(&mut self) -> Vec { - let mut hash_input = self.digest.as_ref().to_vec(); - - // Pad the counter to 32 bytes. - let mut padded_counter = [0; BLAKE_BYTES_PER_HASH]; - let counter_bytes = self.channel_time.n_sent.to_le_bytes(); - padded_counter[0..counter_bytes.len()].copy_from_slice(&counter_bytes); - - hash_input.extend_from_slice(&padded_counter); - - // TODO(spapini): Are we worried about this drawing hash colliding with mix_digest? - - self.channel_time.inc_sent(); - Blake2sHasher::hash(&hash_input).into() - } -} - -#[cfg(test)] -mod tests { - use std::collections::BTreeSet; - - use crate::core::channel::blake2s::Blake2sChannel; - use crate::core::channel::Channel; - use crate::core::fields::qm31::SecureField; - use crate::m31; - - #[test] - fn test_channel_time() { - let mut channel = Blake2sChannel::default(); - - assert_eq!(channel.channel_time.n_challenges, 0); - assert_eq!(channel.channel_time.n_sent, 0); - - channel.draw_random_bytes(); - assert_eq!(channel.channel_time.n_challenges, 0); - assert_eq!(channel.channel_time.n_sent, 1); - - channel.draw_felts(9); - assert_eq!(channel.channel_time.n_challenges, 0); - assert_eq!(channel.channel_time.n_sent, 6); - } - - #[test] - fn test_draw_random_bytes() { - let mut channel = Blake2sChannel::default(); - - let first_random_bytes = channel.draw_random_bytes(); - - // Assert that next random bytes are different. - assert_ne!(first_random_bytes, channel.draw_random_bytes()); - } - - #[test] - pub fn test_draw_felt() { - let mut channel = Blake2sChannel::default(); - - let first_random_felt = channel.draw_felt(); - - // Assert that next random felt is different. - assert_ne!(first_random_felt, channel.draw_felt()); - } - - #[test] - pub fn test_draw_felts() { - let mut channel = Blake2sChannel::default(); - - let mut random_felts = channel.draw_felts(5); - random_felts.extend(channel.draw_felts(4)); - - // Assert that all the random felts are unique. - assert_eq!( - random_felts.len(), - random_felts.iter().collect::>().len() - ); - } - - #[test] - pub fn test_mix_felts() { - let mut channel = Blake2sChannel::default(); - let initial_digest = channel.digest; - let felts: Vec = (0..2) - .map(|i| SecureField::from(m31!(i + 1923782))) - .collect(); - - channel.mix_felts(felts.as_slice()); - - assert_ne!(initial_digest, channel.digest); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/channel/mod.rs b/Stwo_wrapper/crates/prover/src/core/channel/mod.rs deleted file mode 100644 index 3d85b8d..0000000 --- a/Stwo_wrapper/crates/prover/src/core/channel/mod.rs +++ /dev/null @@ -1,57 +0,0 @@ -use super::fields::qm31::SecureField; -use super::vcs::ops::MerkleHasher; - -#[cfg(not(target_arch = "wasm32"))] -mod poseidon252; -#[cfg(not(target_arch = "wasm32"))] -pub use poseidon252::Poseidon252Channel; - -mod blake2s; -pub use blake2s::Blake2sChannel; - -#[cfg(not(target_arch = "wasm32"))] -mod poseidon_bls; -#[cfg(not(target_arch = "wasm32"))] -pub use poseidon_bls::PoseidonBLSChannel; - -pub const EXTENSION_FELTS_PER_HASH: usize = 2; - -#[derive(Clone, Default)] -pub struct ChannelTime { - pub n_challenges: usize, - n_sent: usize, -} - -impl ChannelTime { - fn inc_sent(&mut self) { - self.n_sent += 1; - } - - fn inc_challenges(&mut self) { - self.n_challenges += 1; - self.n_sent = 0; - } -} - -pub trait Channel: Default + Clone { - const BYTES_PER_HASH: usize; - - fn trailing_zeros(&self) -> u32; - - // Mix functions. - fn mix_felts(&mut self, felts: &[SecureField]); - fn mix_nonce(&mut self, nonce: u64); - - // Draw functions. - fn draw_felt(&mut self) -> SecureField; - /// Generates a uniform random vector of SecureField elements. - fn draw_felts(&mut self, n_felts: usize) -> Vec; - /// Returns a vector of random bytes of length `BYTES_PER_HASH`. - fn draw_random_bytes(&mut self) -> Vec; -} - -pub trait MerkleChannel: Default { - type C: Channel; - type H: MerkleHasher; - fn mix_root(channel: &mut Self::C, root: ::Hash); -} diff --git a/Stwo_wrapper/crates/prover/src/core/channel/poseidon252.rs b/Stwo_wrapper/crates/prover/src/core/channel/poseidon252.rs deleted file mode 100644 index 195d1fc..0000000 --- a/Stwo_wrapper/crates/prover/src/core/channel/poseidon252.rs +++ /dev/null @@ -1,190 +0,0 @@ -use std::iter; - -use starknet_crypto::{poseidon_hash, poseidon_hash_many}; -use starknet_ff::FieldElement as FieldElement252; - -use super::{Channel, ChannelTime}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; - -pub const BYTES_PER_FELT252: usize = 31; -pub const FELTS_PER_HASH: usize = 8; - -/// A channel that can be used to draw random elements from a Poseidon252 hash. -#[derive(Clone, Default)] -pub struct Poseidon252Channel { - digest: FieldElement252, - pub channel_time: ChannelTime, -} - -impl Poseidon252Channel { - pub fn digest(&self) -> FieldElement252 { - self.digest - } - pub fn update_digest(&mut self, new_digest: FieldElement252) { - self.digest = new_digest; - self.channel_time.inc_challenges(); - } - fn draw_felt252(&mut self) -> FieldElement252 { - let res = poseidon_hash(self.digest, self.channel_time.n_sent.into()); - self.channel_time.inc_sent(); - res - } - - // TODO(spapini): Understand if we really need uniformity here. - /// Generates a close-to uniform random vector of BaseField elements. - fn draw_base_felts(&mut self) -> [BaseField; 8] { - let shift = (1u64 << 31).into(); - - let mut cur = self.draw_felt252(); - let u32s: [u32; 8] = std::array::from_fn(|_| { - let next = cur.floor_div(shift); - let res = cur - next * shift; - cur = next; - res.try_into().unwrap() - }); - - u32s.into_iter() - .map(|x| BaseField::reduce(x as u64)) - .collect::>() - .try_into() - .unwrap() - } -} - -impl Channel for Poseidon252Channel { - const BYTES_PER_HASH: usize = BYTES_PER_FELT252; - - fn trailing_zeros(&self) -> u32 { - let bytes = self.digest.to_bytes_be(); - u128::from_le_bytes(std::array::from_fn(|i| bytes[i])).trailing_zeros() - } - - // TODO(spapini): Optimize. - fn mix_felts(&mut self, felts: &[SecureField]) { - let shift = (1u64 << 31).into(); - let mut res = Vec::with_capacity(felts.len() / 2 + 2); - res.push(self.digest); - for chunk in felts.chunks(2) { - res.push( - chunk - .iter() - .flat_map(|x| x.to_m31_array()) - .fold(FieldElement252::default(), |cur, y| { - cur * shift + y.0.into() - }), - ); - } - - // TODO(spapini): do we need length padding? - self.update_digest(poseidon_hash_many(&res)); - } - - fn mix_nonce(&mut self, nonce: u64) { - self.update_digest(poseidon_hash(self.digest, nonce.into())); - } - - fn draw_felt(&mut self) -> SecureField { - let felts: [BaseField; FELTS_PER_HASH] = self.draw_base_felts(); - SecureField::from_m31_array(felts[..SECURE_EXTENSION_DEGREE].try_into().unwrap()) - } - - fn draw_felts(&mut self, n_felts: usize) -> Vec { - let mut felts = iter::from_fn(|| Some(self.draw_base_felts())).flatten(); - let secure_felts = iter::from_fn(|| { - Some(SecureField::from_m31_array([ - felts.next()?, - felts.next()?, - felts.next()?, - felts.next()?, - ])) - }); - secure_felts.take(n_felts).collect() - } - - fn draw_random_bytes(&mut self) -> Vec { - let shift = (1u64 << 8).into(); - let mut cur = self.draw_felt252(); - let bytes: [u8; 31] = std::array::from_fn(|_| { - let next = cur.floor_div(shift); - let res = cur - next * shift; - cur = next; - res.try_into().unwrap() - }); - bytes.to_vec() - } -} - -#[cfg(test)] -mod tests { - use std::collections::BTreeSet; - - use crate::core::channel::poseidon252::Poseidon252Channel; - use crate::core::channel::Channel; - use crate::core::fields::qm31::SecureField; - use crate::m31; - - #[test] - fn test_channel_time() { - let mut channel = Poseidon252Channel::default(); - - assert_eq!(channel.channel_time.n_challenges, 0); - assert_eq!(channel.channel_time.n_sent, 0); - - channel.draw_random_bytes(); - assert_eq!(channel.channel_time.n_challenges, 0); - assert_eq!(channel.channel_time.n_sent, 1); - - channel.draw_felts(9); - assert_eq!(channel.channel_time.n_challenges, 0); - assert_eq!(channel.channel_time.n_sent, 6); - } - - #[test] - fn test_draw_random_bytes() { - let mut channel = Poseidon252Channel::default(); - - let first_random_bytes = channel.draw_random_bytes(); - - // Assert that next random bytes are different. - assert_ne!(first_random_bytes, channel.draw_random_bytes()); - } - - #[test] - pub fn test_draw_felt() { - let mut channel = Poseidon252Channel::default(); - - let first_random_felt = channel.draw_felt(); - - // Assert that next random felt is different. - assert_ne!(first_random_felt, channel.draw_felt()); - } - - #[test] - pub fn test_draw_felts() { - let mut channel = Poseidon252Channel::default(); - - let mut random_felts = channel.draw_felts(5); - random_felts.extend(channel.draw_felts(4)); - - // Assert that all the random felts are unique. - assert_eq!( - random_felts.len(), - random_felts.iter().collect::>().len() - ); - } - - #[test] - pub fn test_mix_felts() { - let mut channel = Poseidon252Channel::default(); - let initial_digest = channel.digest; - let felts: Vec = (0..2) - .map(|i| SecureField::from(m31!(i + 1923782))) - .collect(); - - channel.mix_felts(felts.as_slice()); - - assert_ne!(initial_digest, channel.digest); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/channel/poseidon_bls.rs b/Stwo_wrapper/crates/prover/src/core/channel/poseidon_bls.rs deleted file mode 100644 index 65fdce5..0000000 --- a/Stwo_wrapper/crates/prover/src/core/channel/poseidon_bls.rs +++ /dev/null @@ -1,590 +0,0 @@ -use std::iter; - -use ark_bls12_381::Fr as BlsFr; -use ark_ff::{BigInteger, Field, PrimeField}; -use crypto_bigint::{Encoding, NonZero, U256}; - -use super::{Channel, ChannelTime}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; - -pub const BYTES_PER_FELT252: usize = 32; -pub const FELTS_PER_HASH: usize = 8; - -//Optimize constant to be real constants (no conversion) and merge duplicated code in VCS poseidono -fn poseidon_comp_consts(idx: usize) -> BlsFr { - match idx { - 0 => BlsFr::from_be_bytes_mod_order(&[ - 111, 0, 122, 85, 17, 86, 179, 164, 73, 228, 73, 54, 183, 192, 147, 100, 74, 14, 211, - 63, 51, 234, 204, 198, 40, 233, 66, 232, 54, 193, 168, 117, - ]), - 1 => BlsFr::from_be_bytes_mod_order(&[ - 54, 13, 116, 112, 97, 30, 71, 61, 53, 63, 98, 143, 118, 209, 16, 243, 78, 113, 22, 47, - 49, 0, 59, 112, 87, 83, 140, 37, 150, 66, 99, 3, - ]), - 2 => BlsFr::from_be_bytes_mod_order(&[ - 75, 95, 236, 58, 160, 115, 223, 68, 1, 144, 145, 240, 7, 164, 76, 169, 150, 72, 73, - 101, 247, 3, 109, 206, 62, 157, 9, 119, 237, 205, 192, 246, - ]), - 3 => BlsFr::from_be_bytes_mod_order(&[ - 103, 207, 24, 104, 175, 99, 150, 192, 184, 76, 206, 113, 94, 83, 159, 132, 158, 6, 205, - 28, 56, 58, 197, 176, 97, 0, 199, 107, 204, 151, 58, 17, - ]), - 4 => BlsFr::from_be_bytes_mod_order(&[ - 85, 93, 180, 209, 220, 237, 129, 159, 93, 61, 231, 15, 222, 131, 241, 199, 211, 232, - 201, 137, 104, 229, 22, 162, 58, 119, 26, 92, 156, 130, 87, 170, - ]), - 5 => BlsFr::from_be_bytes_mod_order(&[ - 43, 171, 148, 215, 174, 34, 45, 19, 93, 195, 198, 197, 254, 191, 170, 49, 73, 8, 172, - 47, 18, 235, 224, 111, 189, 183, 66, 19, 191, 99, 24, 139, - ]), - 6 => BlsFr::from_be_bytes_mod_order(&[ - 102, 244, 75, 229, 41, 102, 130, 196, 250, 120, 130, 121, 157, 109, 208, 73, 182, 215, - 210, 201, 80, 204, 249, 140, 242, 229, 13, 109, 30, 187, 119, 194, - ]), - 7 => BlsFr::from_be_bytes_mod_order(&[ - 21, 12, 147, 254, 246, 82, 251, 28, 43, 240, 62, 26, 41, 170, 135, 31, 239, 119, 231, - 215, 54, 118, 108, 93, 9, 57, 217, 39, 83, 204, 93, 200, - ]), - 8 => BlsFr::from_be_bytes_mod_order(&[ - 50, 112, 102, 30, 104, 146, 139, 58, 149, 93, 85, 219, 86, 220, 87, 193, 3, 204, 10, - 96, 20, 30, 137, 78, 20, 37, 157, 206, 83, 119, 130, 178, - ]), - 9 => BlsFr::from_be_bytes_mod_order(&[ - 7, 63, 17, 111, 4, 18, 46, 37, 160, 183, 175, 228, 226, 5, 114, 153, 180, 7, 195, 112, - 242, 181, 161, 204, 206, 159, 185, 255, 195, 69, 175, 179, - ]), - 10 => BlsFr::from_be_bytes_mod_order(&[ - 64, 159, 218, 34, 85, 140, 254, 77, 61, 216, 220, 226, 79, 105, 231, 111, 140, 42, 174, - 177, 221, 15, 9, 214, 94, 101, 76, 113, 243, 42, 162, 63, - ]), - 11 => BlsFr::from_be_bytes_mod_order(&[ - 42, 50, 236, 92, 78, 229, 177, 131, 122, 255, 208, 156, 31, 83, 245, 253, 85, 201, 205, - 32, 97, 174, 147, 202, 142, 186, 215, 111, 199, 21, 84, 216, - ]), - 12 => BlsFr::from_be_bytes_mod_order(&[ - 88, 72, 235, 235, 89, 35, 233, 37, 85, 183, 18, 79, 255, 186, 93, 107, 213, 113, 198, - 249, 132, 25, 94, 185, 207, 211, 163, 232, 235, 85, 177, 212, - ]), - 13 => BlsFr::from_be_bytes_mod_order(&[ - 39, 3, 38, 238, 3, 157, 241, 158, 101, 30, 44, 252, 116, 6, 40, 202, 99, 77, 36, 252, - 110, 37, 89, 242, 45, 140, 203, 226, 146, 239, 238, 173, - ]), - 14 => BlsFr::from_be_bytes_mod_order(&[ - 39, 198, 100, 42, 198, 51, 188, 102, 220, 16, 15, 231, 252, 250, 84, 145, 138, 248, - 149, 188, 224, 18, 241, 130, 160, 104, 252, 55, 193, 130, 226, 116, - ]), - 15 => BlsFr::from_be_bytes_mod_order(&[ - 27, 223, 216, 176, 20, 1, 199, 10, 210, 127, 87, 57, 105, 137, 18, 157, 113, 14, 31, - 182, 171, 151, 106, 69, 156, 161, 134, 130, 226, 109, 127, 249, - ]), - 16 => BlsFr::from_be_bytes_mod_order(&[ - 73, 27, 155, 166, 152, 59, 207, 159, 5, 254, 71, 148, 173, 180, 74, 48, 135, 155, 248, - 40, 150, 98, 225, 245, 125, 144, 246, 114, 65, 78, 138, 74, - ]), - 17 => BlsFr::from_be_bytes_mod_order(&[ - 22, 42, 20, 198, 47, 154, 137, 184, 20, 185, 214, 169, 200, 77, 214, 120, 244, 246, - 251, 63, 144, 84, 211, 115, 200, 50, 216, 36, 38, 26, 53, 234, - ]), - 18 => BlsFr::from_be_bytes_mod_order(&[ - 45, 25, 62, 15, 118, 222, 88, 107, 42, 246, 247, 158, 49, 39, 254, 234, 172, 10, 31, - 199, 30, 44, 240, 192, 247, 152, 36, 102, 123, 91, 107, 236, - ]), - 19 => BlsFr::from_be_bytes_mod_order(&[ - 70, 239, 216, 169, 162, 98, 214, 216, 253, 201, 202, 92, 4, 176, 152, 47, 36, 221, 204, - 110, 152, 99, 136, 90, 106, 115, 42, 57, 6, 160, 123, 149, - ]), - 20 => BlsFr::from_be_bytes_mod_order(&[ - 80, 151, 23, 224, 194, 0, 227, 201, 45, 141, 202, 41, 115, 179, 219, 69, 240, 120, 130, - 148, 53, 26, 208, 122, 231, 92, 187, 120, 6, 147, 167, 152, - ]), - 21 => BlsFr::from_be_bytes_mod_order(&[ - 114, 153, 178, 132, 100, 168, 201, 79, 185, 212, 223, 97, 56, 15, 57, 192, 220, 169, - 194, 192, 20, 17, 135, 137, 226, 39, 37, 40, 32, 240, 27, 252, - ]), - 22 => BlsFr::from_be_bytes_mod_order(&[ - 4, 76, 163, 204, 74, 133, 215, 59, 129, 105, 110, 241, 16, 78, 103, 79, 79, 239, 248, - 41, 132, 153, 15, 248, 93, 11, 245, 141, 200, 164, 170, 148, - ]), - 23 => BlsFr::from_be_bytes_mod_order(&[ - 28, 186, 242, 179, 113, 218, 198, 168, 29, 4, 83, 65, 109, 62, 35, 92, 184, 217, 226, - 212, 243, 20, 244, 111, 97, 152, 120, 95, 12, 214, 185, 175, - ]), - 24 => BlsFr::from_be_bytes_mod_order(&[ - 29, 91, 39, 119, 105, 44, 32, 91, 14, 108, 73, 208, 97, 182, 181, 244, 41, 60, 74, 176, - 56, 253, 187, 220, 52, 62, 7, 97, 15, 63, 237, 229, - ]), - 25 => BlsFr::from_be_bytes_mod_order(&[ - 86, 174, 124, 122, 82, 147, 189, 194, 62, 133, 225, 105, 140, 129, 199, 127, 138, 216, - 140, 75, 51, 165, 120, 4, 55, 173, 4, 124, 110, 219, 89, 186, - ]), - 26 => BlsFr::from_be_bytes_mod_order(&[ - 46, 155, 219, 186, 61, 211, 75, 255, 170, 48, 83, 91, 221, 116, 154, 126, 6, 169, 173, - 176, 193, 230, 249, 98, 246, 14, 151, 27, 141, 115, 176, 79, - ]), - 27 => BlsFr::from_be_bytes_mod_order(&[ - 45, 225, 24, 134, 177, 128, 17, 202, 139, 213, 186, 227, 105, 105, 41, 159, 222, 64, - 251, 226, 109, 4, 123, 5, 3, 90, 19, 102, 31, 34, 65, 139, - ]), - 28 => BlsFr::from_be_bytes_mod_order(&[ - 46, 7, 222, 23, 128, 184, 167, 13, 13, 91, 74, 63, 24, 65, 220, 216, 42, 185, 57, 92, - 68, 155, 233, 71, 188, 153, 136, 132, 186, 150, 167, 33, - ]), - 29 => BlsFr::from_be_bytes_mod_order(&[ - 15, 105, 241, 133, 77, 32, 202, 12, 187, 219, 99, 219, 213, 45, 173, 22, 37, 4, 64, - 169, 157, 107, 138, 243, 130, 94, 76, 43, 183, 73, 37, 202, - ]), - 30 => BlsFr::from_be_bytes_mod_order(&[ - 93, 201, 135, 49, 142, 110, 89, 193, 175, 184, 123, 101, 93, 213, 140, 193, 210, 46, - 81, 58, 5, 131, 140, 212, 88, 93, 4, 177, 53, 185, 87, 202, - ]), - 31 => BlsFr::from_be_bytes_mod_order(&[ - 72, 183, 37, 117, 133, 113, 201, 223, 108, 1, 220, 99, 154, 133, 240, 114, 151, 105, - 107, 27, 182, 120, 99, 58, 41, 220, 145, 222, 149, 239, 83, 246, - ]), - 32 => BlsFr::from_be_bytes_mod_order(&[ - 94, 86, 94, 8, 192, 130, 16, 153, 37, 107, 86, 73, 14, 174, 225, 213, 115, 175, 209, - 11, 182, 209, 125, 19, 202, 78, 92, 97, 27, 42, 55, 24, - ]), - 33 => BlsFr::from_be_bytes_mod_order(&[ - 46, 177, 178, 84, 23, 254, 23, 103, 13, 19, 93, 198, 57, 251, 9, 164, 108, 229, 17, 53, - 7, 249, 109, 233, 129, 108, 5, 148, 34, 220, 112, 94, - ]), - 34 => BlsFr::from_be_bytes_mod_order(&[ - 17, 92, 208, 160, 100, 60, 251, 152, 140, 36, 203, 68, 195, 250, 180, 138, 255, 54, - 198, 97, 210, 108, 196, 45, 184, 177, 189, 244, 149, 59, 216, 44, - ]), - 35 => BlsFr::from_be_bytes_mod_order(&[ - 38, 202, 41, 63, 123, 44, 70, 45, 6, 109, 115, 120, 185, 153, 134, 139, 187, 87, 221, - 241, 78, 15, 149, 138, 222, 128, 22, 18, 49, 29, 4, 205, - ]), - 36 => BlsFr::from_be_bytes_mod_order(&[ - 65, 71, 64, 13, 142, 26, 172, 207, 49, 26, 107, 91, 118, 32, 17, 171, 62, 69, 50, 110, - 77, 75, 157, 226, 105, 146, 129, 107, 153, 197, 40, 172, - ]), - 37 => BlsFr::from_be_bytes_mod_order(&[ - 107, 13, 183, 220, 204, 75, 161, 178, 104, 246, 189, 204, 77, 55, 40, 72, 212, 167, 41, - 118, 194, 104, 234, 48, 81, 154, 47, 115, 230, 219, 77, 85, - ]), - 38 => BlsFr::from_be_bytes_mod_order(&[ - 23, 191, 27, 147, 196, 199, 224, 26, 42, 131, 10, 161, 98, 65, 44, 217, 15, 22, 11, - 249, 247, 30, 150, 127, 245, 32, 157, 20, 178, 72, 32, 202, - ]), - 39 => BlsFr::from_be_bytes_mod_order(&[ - 75, 67, 28, 217, 239, 237, 188, 148, 207, 30, 202, 111, 158, 156, 24, 57, 208, 230, - 106, 139, 255, 168, 200, 70, 76, 172, 129, 163, 157, 60, 248, 241, - ]), - 40 => BlsFr::from_be_bytes_mod_order(&[ - 53, 180, 26, 122, 196, 243, 197, 113, 162, 79, 132, 86, 54, 156, 133, 223, 224, 60, 3, - 84, 189, 140, 253, 56, 5, 200, 111, 46, 125, 194, 147, 197, - ]), - 41 => BlsFr::from_be_bytes_mod_order(&[ - 59, 20, 128, 8, 5, 35, 196, 57, 67, 89, 39, 153, 72, 73, 190, 169, 100, 225, 77, 59, - 235, 45, 221, 222, 114, 172, 21, 106, 244, 53, 208, 158, - ]), - 42 => BlsFr::from_be_bytes_mod_order(&[ - 44, 198, 129, 0, 49, 220, 27, 13, 73, 80, 133, 109, 201, 7, 213, 117, 8, 226, 134, 68, - 42, 45, 62, 178, 39, 22, 24, 216, 116, 177, 76, 109, - ]), - 43 => BlsFr::from_be_bytes_mod_order(&[ - 111, 65, 65, 200, 64, 28, 90, 57, 91, 166, 121, 14, 253, 113, 199, 12, 4, 175, 234, 6, - 195, 201, 40, 38, 188, 171, 221, 92, 181, 71, 125, 81, - ]), - 44 => BlsFr::from_be_bytes_mod_order(&[ - 37, 189, 187, 237, 161, 189, 232, 193, 5, 150, 24, 226, 175, 210, 239, 153, 158, 81, - 122, 169, 59, 120, 52, 29, 145, 243, 24, 192, 159, 12, 181, 102, - ]), - 45 => BlsFr::from_be_bytes_mod_order(&[ - 57, 42, 74, 135, 88, 224, 110, 232, 185, 95, 51, 194, 93, 222, 138, 192, 42, 94, 208, - 162, 123, 97, 146, 108, 198, 49, 52, 135, 7, 63, 127, 123, - ]), - 46 => BlsFr::from_be_bytes_mod_order(&[ - 39, 42, 85, 135, 138, 8, 68, 43, 154, 166, 17, 31, 77, 224, 9, 72, 94, 106, 111, 209, - 93, 184, 147, 101, 231, 187, 206, 240, 46, 181, 134, 108, - ]), - 47 => BlsFr::from_be_bytes_mod_order(&[ - 99, 30, 193, 214, 210, 141, 217, 232, 36, 238, 137, 163, 7, 48, 174, 247, 171, 70, 58, - 207, 201, 209, 132, 179, 85, 170, 5, 253, 105, 56, 234, 181, - ]), - 48 => BlsFr::from_be_bytes_mod_order(&[ - 78, 182, 253, 161, 15, 208, 251, 222, 2, 199, 68, 155, 251, 221, 195, 91, 205, 130, 37, - 231, 229, 195, 131, 58, 8, 24, 161, 0, 64, 157, 198, 242, - ]), - 49 => BlsFr::from_be_bytes_mod_order(&[ - 45, 91, 48, 139, 12, 240, 44, 223, 239, 161, 60, 78, 96, 226, 98, 57, 166, 235, 186, 1, - 22, 148, 221, 18, 155, 146, 91, 60, 91, 33, 224, 226, - ]), - 50 => BlsFr::from_be_bytes_mod_order(&[ - 22, 84, 159, 198, 175, 47, 59, 114, 221, 93, 41, 61, 114, 226, 229, 242, 68, 223, 244, - 47, 24, 180, 108, 86, 239, 56, 197, 124, 49, 22, 115, 172, - ]), - 51 => BlsFr::from_be_bytes_mod_order(&[ - 66, 51, 38, 119, 255, 53, 156, 94, 141, 184, 54, 217, 245, 251, 84, 130, 46, 57, 189, - 94, 34, 52, 11, 185, 186, 151, 91, 161, 169, 43, 227, 130, - ]), - 52 => BlsFr::from_be_bytes_mod_order(&[ - 73, 215, 210, 192, 180, 73, 229, 23, 155, 197, 204, 195, 180, 76, 96, 117, 217, 132, - 155, 86, 16, 70, 95, 9, 234, 114, 93, 220, 151, 114, 58, 148, - ]), - 53 => BlsFr::from_be_bytes_mod_order(&[ - 100, 194, 15, 185, 13, 122, 0, 56, 49, 117, 124, 196, 198, 34, 111, 110, 73, 133, 252, - 158, 203, 65, 107, 159, 104, 76, 160, 53, 29, 150, 121, 4, - ]), - 54 => BlsFr::from_be_bytes_mod_order(&[ - 89, 207, 244, 13, 232, 59, 82, 180, 27, 196, 67, 215, 151, 149, 16, 215, 113, 201, 64, - 185, 117, 140, 168, 32, 254, 115, 181, 200, 213, 88, 9, 52, - ]), - 55 => BlsFr::from_be_bytes_mod_order(&[ - 83, 219, 39, 49, 115, 12, 57, 176, 78, 221, 135, 95, 227, 183, 200, 130, 128, 130, 133, - 205, 188, 98, 29, 122, 244, 248, 13, 213, 62, 187, 113, 176, - ]), - 56 => BlsFr::from_be_bytes_mod_order(&[ - 27, 16, 187, 122, 130, 175, 206, 57, 250, 105, 195, 162, 173, 82, 247, 109, 118, 57, - 130, 101, 52, 66, 3, 17, 155, 113, 38, 217, 180, 104, 96, 223, - ]), - 57 => BlsFr::from_be_bytes_mod_order(&[ - 86, 27, 96, 18, 214, 102, 191, 225, 121, 196, 221, 127, 132, 205, 209, 83, 21, 150, - 211, 170, 199, 197, 112, 12, 235, 49, 159, 145, 4, 106, 99, 201, - ]), - 58 => BlsFr::from_be_bytes_mod_order(&[ - 15, 30, 117, 5, 235, 217, 29, 47, 199, 156, 45, 247, 220, 152, 163, 190, 209, 179, 105, - 104, 186, 4, 5, 192, 144, 210, 127, 106, 0, 183, 223, 200, - ]), - 59 => BlsFr::from_be_bytes_mod_order(&[ - 47, 49, 63, 175, 13, 63, 97, 135, 83, 122, 116, 151, 163, 180, 63, 70, 121, 127, 214, - 227, 241, 142, 177, 202, 255, 69, 119, 86, 184, 25, 187, 32, - ]), - 60 => BlsFr::from_be_bytes_mod_order(&[ - 58, 92, 187, 109, 228, 80, 180, 129, 250, 60, 166, 28, 14, 209, 91, 197, 92, 173, 17, - 235, 240, 247, 206, 184, 240, 188, 62, 115, 46, 203, 38, 246, - ]), - 61 => BlsFr::from_be_bytes_mod_order(&[ - 104, 29, 147, 65, 27, 248, 206, 99, 246, 113, 106, 239, 189, 14, 36, 80, 100, 84, 192, - 52, 142, 227, 143, 171, 235, 38, 71, 2, 113, 76, 207, 148, - ]), - 62 => BlsFr::from_be_bytes_mod_order(&[ - 81, 120, 233, 64, 245, 0, 4, 49, 38, 70, 180, 54, 114, 127, 14, 128, 167, 184, 242, - 233, 238, 31, 220, 103, 124, 72, 49, 167, 103, 39, 119, 251, - ]), - 63 => BlsFr::from_be_bytes_mod_order(&[ - 61, 171, 84, 188, 155, 239, 104, 141, 217, 32, 134, 226, 83, 180, 57, 214, 81, 186, - 166, 226, 15, 137, 43, 98, 134, 85, 39, 203, 202, 145, 89, 130, - ]), - 64 => BlsFr::from_be_bytes_mod_order(&[ - 75, 60, 231, 83, 17, 33, 143, 154, 233, 5, 248, 78, 170, 91, 43, 56, 24, 68, 139, 191, - 57, 114, 225, 170, 214, 157, 227, 33, 0, 144, 21, 208, - ]), - 65 => BlsFr::from_be_bytes_mod_order(&[ - 6, 219, 251, 66, 185, 121, 136, 77, 226, 128, 211, 22, 112, 18, 63, 116, 76, 36, 179, - 59, 65, 15, 239, 212, 54, 128, 69, 172, 242, 183, 26, 227, - ]), - 66 => BlsFr::from_be_bytes_mod_order(&[ - 6, 141, 107, 70, 8, 170, 232, 16, 198, 240, 57, 234, 25, 115, 166, 62, 184, 210, 222, - 114, 227, 210, 201, 236, 167, 252, 50, 210, 47, 24, 185, 211, - ]), - 67 => BlsFr::from_be_bytes_mod_order(&[ - 76, 92, 37, 69, 137, 169, 42, 54, 8, 74, 87, 211, 177, 217, 100, 39, 138, 204, 126, 79, - 232, 246, 159, 41, 85, 149, 79, 39, 167, 156, 235, 239, - ]), - 68 => BlsFr::from_be_bytes_mod_order(&[ - 108, 186, 197, 225, 112, 9, 132, 235, 195, 45, 161, 91, 75, 185, 104, 63, 170, 186, - 181, 95, 103, 204, 196, 247, 29, 149, 96, 179, 71, 90, 119, 235, - ]), - 69 => BlsFr::from_be_bytes_mod_order(&[ - 70, 3, 196, 3, 187, 250, 154, 23, 115, 138, 92, 98, 120, 234, 171, 28, 55, 236, 48, - 176, 115, 122, 162, 64, 159, 196, 137, 128, 105, 235, 152, 60, - ]), - 70 => BlsFr::from_be_bytes_mod_order(&[ - 104, 148, 231, 226, 43, 44, 29, 92, 112, 167, 18, 166, 52, 90, 230, 177, 146, 169, 200, - 51, 169, 35, 76, 49, 197, 106, 172, 209, 107, 194, 241, 0, - ]), - 71 => BlsFr::from_be_bytes_mod_order(&[ - 91, 226, 203, 188, 68, 5, 58, 208, 138, 250, 77, 30, 171, 199, 243, 210, 49, 238, 167, - 153, 185, 63, 34, 110, 144, 91, 125, 77, 101, 197, 142, 187, - ]), - 72 => BlsFr::from_be_bytes_mod_order(&[ - 88, 229, 95, 40, 123, 69, 58, 152, 8, 98, 74, 140, 42, 53, 61, 82, 141, 160, 247, 231, - 19, 165, 198, 208, 215, 113, 30, 71, 6, 63, 166, 17, - ]), - 73 => BlsFr::from_be_bytes_mod_order(&[ - 54, 110, 191, 175, 163, 173, 56, 28, 14, 226, 88, 201, 184, 253, 252, 205, 184, 104, - 167, 215, 225, 241, 246, 154, 43, 93, 252, 197, 87, 37, 85, 223, - ]), - 74 => BlsFr::from_be_bytes_mod_order(&[ - 69, 118, 106, 183, 40, 150, 140, 100, 47, 144, 217, 124, 207, 85, 4, 221, 193, 5, 24, - 168, 25, 235, 188, 196, 208, 156, 63, 93, 120, 77, 103, 206, - ]), - 75 => BlsFr::from_be_bytes_mod_order(&[ - 57, 103, 143, 101, 81, 47, 30, 228, 4, 219, 48, 36, 244, 29, 63, 86, 126, 246, 109, - 137, 208, 68, 208, 34, 230, 188, 34, 158, 149, 188, 118, 177, - ]), - 76 => BlsFr::from_be_bytes_mod_order(&[ - 70, 58, 237, 29, 47, 31, 149, 94, 48, 120, 190, 91, 247, 191, 196, 111, 192, 235, 140, - 81, 85, 25, 6, 168, 134, 143, 24, 255, 174, 48, 207, 79, - ]), - 77 => BlsFr::from_be_bytes_mod_order(&[ - 33, 102, 143, 1, 106, 128, 99, 192, 213, 139, 119, 80, 163, 188, 47, 225, 207, 130, - 194, 95, 153, 220, 1, 164, 229, 52, 200, 143, 229, 61, 133, 254, - ]), - 78 => BlsFr::from_be_bytes_mod_order(&[ - 57, 208, 9, 148, 168, 165, 4, 106, 27, 199, 73, 54, 62, 152, 167, 104, 227, 77, 234, - 86, 67, 159, 225, 149, 75, 239, 66, 155, 197, 51, 22, 8, - ]), - 79 => BlsFr::from_be_bytes_mod_order(&[ - 77, 127, 93, 205, 120, 236, 233, 169, 51, 152, 77, 227, 44, 11, 72, 250, 194, 187, 169, - 31, 38, 25, 150, 184, 233, 209, 2, 23, 115, 189, 7, 204, - ]), - _ => BlsFr::ZERO, - } -} - -/// A channel that can be used to draw random elements from a PoseidonBLS hash. -#[derive(Clone, Default)] -pub struct PoseidonBLSChannel { - digest: BlsFr, - pub channel_time: ChannelTime, -} - -pub fn poseidon_hash_bls(x: BlsFr, y: BlsFr) -> BlsFr { - let mut state = [x, y, BlsFr::ZERO]; - poseidon_permute_comp_bls(&mut state); - state[0] + x -} - -pub fn poseidon_permute_comp_bls(state: &mut [BlsFr; 3]) { - let mut idx = 0; - mix(state); - - // Full rounds - for _ in 0..4 { - round_comp(state, idx, true); - idx += 3; - } - - // Partial rounds - for _ in 0..56 { - round_comp(state, idx, false); - idx += 1; - } - - // Full rounds - for _ in 0..4 { - round_comp(state, idx, true); - idx += 3; - } -} - -#[inline] -fn round_comp(state: &mut [BlsFr; 3], idx: usize, full: bool) { - if full { - state[0] += poseidon_comp_consts(idx); - state[1] += poseidon_comp_consts(idx + 1); - state[2] += poseidon_comp_consts(idx + 2); - // Optimize multiplication - state[0] = state[0] * state[0] * state[0] * state[0] * state[0]; - state[1] = state[1] * state[1] * state[1] * state[1] * state[1]; - state[2] = state[2] * state[2] * state[2] * state[2] * state[2]; - } else { - state[0] += poseidon_comp_consts(idx); - state[2] = state[2] * state[2] * state[2] * state[2] * state[2]; - } - mix(state); -} - -#[inline(always)] -fn mix(state: &mut [BlsFr; 3]) { - state[0] = state[0] + state[1] + state[2]; - state[1] = state[0] + state[1]; - state[2] = state[0] + state[2]; -} - -pub fn poseidon_hash_many_bls(msgs: &[BlsFr]) -> BlsFr { - let mut state = [BlsFr::ZERO, BlsFr::ZERO, BlsFr::ZERO]; - let mut iter = msgs.chunks_exact(2); - - for msg in iter.by_ref() { - state[0] += msg[0]; - state[1] += msg[1]; - poseidon_permute_comp_bls(&mut state); - } - let r = iter.remainder(); - if r.len() == 1 { - state[0] += r[0]; - } - state[r.len()] += BlsFr::ONE; - poseidon_permute_comp_bls(&mut state); - - state[0] -} - -impl PoseidonBLSChannel { - pub fn digest(&self) -> BlsFr { - self.digest - } - pub fn update_digest(&mut self, new_digest: BlsFr) { - self.digest = new_digest; - self.channel_time.inc_challenges(); - } - fn draw_felt252(&mut self) -> BlsFr { - let res = poseidon_hash_bls(self.digest, BlsFr::from(self.channel_time.n_sent as u64)); - self.channel_time.inc_sent(); - res - } - - // TODO(spapini): Understand if we really need uniformity here. - /// Generates a close-to uniform random vector of BaseField elements. - fn draw_base_felts(&mut self) -> [BaseField; 8] { - let shift = NonZero::new(U256::from_u64(1u64 << 31)).unwrap(); - - let mut cur = self.draw_felt252(); - let u32s: [u32; 8usize] = std::array::from_fn(|_| { - let (quotient, reminder) = - U256::from_be_slice(&cur.into_bigint().to_bytes_be()).div_rem(&shift); - cur = BlsFr::from_be_bytes_mod_order("ient.to_be_bytes()); - u32::from_str_radix(&reminder.to_string(),16).unwrap() - }); - - u32s.into_iter() - .map(|x| BaseField::reduce(x as u64)) - .collect::>() - .try_into() - .unwrap() - } -} - -impl Channel for PoseidonBLSChannel { - const BYTES_PER_HASH: usize = BYTES_PER_FELT252; - - fn trailing_zeros(&self) -> u32 { - let bytes = self.digest.into_bigint().to_bytes_be(); - u128::from_le_bytes(std::array::from_fn(|i| bytes[i])).trailing_zeros() - } - - // TODO(spapini): Optimize. - fn mix_felts(&mut self, felts: &[SecureField]) { - let shift = BlsFr::from(1u64 << 31); - let mut res = Vec::with_capacity(felts.len() / 2 + 2); - res.push(self.digest); - for chunk in felts.chunks(2) { - res.push( - chunk - .iter() - .flat_map(|x| x.to_m31_array()) - .fold(BlsFr::default(), |cur, y| { - cur * shift + BlsFr::from_be_bytes_mod_order(&y.0.to_be_bytes()) - }), - ); - } - - // TODO(spapini): do we need length padding? - self.update_digest(poseidon_hash_many_bls(&res)); - } - - fn mix_nonce(&mut self, nonce: u64) { - self.update_digest(poseidon_hash_bls(self.digest, nonce.into())); - } - - fn draw_felt(&mut self) -> SecureField { - let felts: [BaseField; FELTS_PER_HASH] = self.draw_base_felts(); - SecureField::from_m31_array(felts[..SECURE_EXTENSION_DEGREE].try_into().unwrap()) - } - - fn draw_felts(&mut self, n_felts: usize) -> Vec { - let mut felts = iter::from_fn(|| Some(self.draw_base_felts())).flatten(); - let secure_felts = iter::from_fn(|| { - Some(SecureField::from_m31_array([ - felts.next()?, - felts.next()?, - felts.next()?, - felts.next()?, - ])) - }); - secure_felts.take(n_felts).collect() - } - - fn draw_random_bytes(&mut self) -> Vec { - let shift = NonZero::new(U256::from_u64(1u64 << 8)).unwrap(); - let mut cur = self.draw_felt252(); - let bytes: [u8; 31] = std::array::from_fn(|_| { - let (quotient, reminder) = - U256::from_be_slice(&cur.into_bigint().to_bytes_be()).div_rem(&shift); - cur = BlsFr::from_be_bytes_mod_order("ient.to_be_bytes()); - u8::from_str_radix(&reminder.to_string(),16).unwrap() - }); - bytes.to_vec() - } -} - -#[cfg(test)] -mod tests { - use std::collections::BTreeSet; - - use crate::core::channel::poseidon_bls::PoseidonBLSChannel; - use crate::core::channel::Channel; - use crate::core::fields::qm31::SecureField; - use crate::m31; - - #[test] - fn test_channel_time() { - let mut channel = PoseidonBLSChannel::default(); - - assert_eq!(channel.channel_time.n_challenges, 0); - assert_eq!(channel.channel_time.n_sent, 0); - - channel.draw_random_bytes(); - assert_eq!(channel.channel_time.n_challenges, 0); - assert_eq!(channel.channel_time.n_sent, 1); - - channel.draw_felts(9); - assert_eq!(channel.channel_time.n_challenges, 0); - assert_eq!(channel.channel_time.n_sent, 6); - } - - #[test] - fn test_draw_random_bytes() { - let mut channel = PoseidonBLSChannel::default(); - - let first_random_bytes = channel.draw_random_bytes(); - - // Assert that next random bytes are different. - assert_ne!(first_random_bytes, channel.draw_random_bytes()); - } - - #[test] - pub fn test_draw_felt() { - let mut channel = PoseidonBLSChannel::default(); - - let first_random_felt = channel.draw_felt(); - - // Assert that next random felt is different. - assert_ne!(first_random_felt, channel.draw_felt()); - } - - #[test] - pub fn test_draw_felts() { - let mut channel = PoseidonBLSChannel::default(); - - let mut random_felts = channel.draw_felts(5); - random_felts.extend(channel.draw_felts(4)); - - // Assert that all the random felts are unique. - assert_eq!( - random_felts.len(), - random_felts.iter().collect::>().len() - ); - } - - #[test] - pub fn test_mix_felts() { - let mut channel = PoseidonBLSChannel::default(); - let initial_digest = channel.digest; - let felts: Vec = (0..2) - .map(|i| SecureField::from(m31!(i + 1923782))) - .collect(); - - channel.mix_felts(felts.as_slice()); - - assert_ne!(initial_digest, channel.digest); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/circle.rs b/Stwo_wrapper/crates/prover/src/core/circle.rs deleted file mode 100644 index 8804840..0000000 --- a/Stwo_wrapper/crates/prover/src/core/circle.rs +++ /dev/null @@ -1,561 +0,0 @@ -use std::ops::{Add, Div, Mul, Neg, Sub}; - -use num_traits::{One, Zero}; - -use super::fields::m31::{BaseField, M31}; -use super::fields::qm31::SecureField; -use super::fields::{ComplexConjugate, Field, FieldExpOps}; -use crate::core::channel::Channel; -use crate::core::fields::qm31::P4; -use crate::math::utils::egcd; - -/// A point on the complex circle. Treated as an additive group. -#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct CirclePoint { - pub x: F, - pub y: F, -} - -impl + FieldExpOps + Sub + Neg> CirclePoint { - pub fn zero() -> Self { - Self { - x: F::one(), - y: F::zero(), - } - } - - pub fn double(&self) -> Self { - *self + *self - } - - /// Applies the circle's x-coordinate doubling map. - /// - /// # Examples - /// - /// ``` - /// use stwo_prover::core::circle::{CirclePoint, M31_CIRCLE_GEN}; - /// use stwo_prover::core::fields::m31::M31; - /// let p = M31_CIRCLE_GEN.mul(17); - /// assert_eq!(CirclePoint::double_x(p.x), (p + p).x); - /// ``` - pub fn double_x(x: F) -> F { - let sx = x.square(); - sx + sx - F::one() - } - - /// Returns the log order of a point. - /// - /// All points have an order of the form `2^k`. - /// - /// # Examples - /// - /// ``` - /// use stwo_prover::core::circle::{CirclePoint, M31_CIRCLE_GEN, M31_CIRCLE_LOG_ORDER}; - /// use stwo_prover::core::fields::m31::M31; - /// assert_eq!(M31_CIRCLE_GEN.log_order(), M31_CIRCLE_LOG_ORDER); - /// ``` - pub fn log_order(&self) -> u32 - where - F: PartialEq + Eq, - { - // we only need the x-coordinate to check order since the only point - // with x=1 is the circle's identity - let mut res = 0; - let mut cur = self.x; - while cur != F::one() { - cur = Self::double_x(cur); - res += 1; - } - res - } - - pub fn mul(&self, mut scalar: u128) -> CirclePoint { - let mut res = Self::zero(); - let mut cur = *self; - while scalar > 0 { - if scalar & 1 == 1 { - res = res + cur; - } - cur = cur.double(); - scalar >>= 1; - } - res - } - - pub fn repeated_double(&self, n: u32) -> Self { - let mut res = *self; - for _ in 0..n { - res = res.double(); - } - res - } - - pub fn conjugate(&self) -> CirclePoint { - Self { - x: self.x, - y: -self.y, - } - } - - pub fn antipode(&self) -> CirclePoint { - Self { - x: -self.x, - y: -self.y, - } - } - - pub fn into_ef>(&self) -> CirclePoint { - CirclePoint { - x: self.x.into(), - y: self.y.into(), - } - } - - pub fn mul_signed(&self, off: isize) -> CirclePoint { - if off > 0 { - self.mul(off as u128) - } else { - self.conjugate().mul(-off as u128) - } - } -} - -impl + FieldExpOps + Sub + Neg> Add - for CirclePoint -{ - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - let x = self.x * rhs.x - self.y * rhs.y; - let y = self.x * rhs.y + self.y * rhs.x; - Self { x, y } - } -} - -impl + FieldExpOps + Sub + Neg> Neg - for CirclePoint -{ - type Output = Self; - - fn neg(self) -> Self::Output { - self.conjugate() - } -} - -impl + FieldExpOps + Sub + Neg> Sub - for CirclePoint -{ - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - self + (-rhs) - } -} - -impl ComplexConjugate for CirclePoint { - fn complex_conjugate(&self) -> Self { - Self { - x: self.x.complex_conjugate(), - y: self.y.complex_conjugate(), - } - } -} - -impl CirclePoint { - pub fn get_point(index: u128) -> Self { - assert!(index < SECURE_FIELD_CIRCLE_ORDER); - SECURE_FIELD_CIRCLE_GEN.mul(index) - } - - pub fn get_random_point(channel: &mut C) -> Self { - let t = channel.draw_felt(); - let t_square = t.square(); - - let one_plus_tsquared_inv = t_square.add(SecureField::one()).inverse(); - - let x = SecureField::one() - .add(t_square.neg()) - .mul(one_plus_tsquared_inv); - let y = t.double().mul(one_plus_tsquared_inv); - - Self { x, y } - } -} - -/// A generator for the circle group over [M31]. -/// -/// # Examples -/// -/// ``` -/// use stwo_prover::core::circle::{CirclePoint, M31_CIRCLE_GEN}; -/// use stwo_prover::core::fields::m31::M31; -/// -/// // Adding a generator to itself (2^30) times should NOT yield the identity. -/// let circle_point = M31_CIRCLE_GEN.repeated_double(30); -/// assert_ne!(circle_point, CirclePoint::zero()); -/// -/// // Shown above ord(M31_CIRCLE_GEN) > 2^30 . Group order is 2^31. -/// // Ord(M31_CIRCLE_GEN) must be a divisor of it, Hence ord(M31_CIRCLE_GEN) = 2^31. -/// // Adding the generator to itself (2^31) times should yield the identity. -/// let circle_point = M31_CIRCLE_GEN.repeated_double(31); -/// assert_eq!(circle_point, CirclePoint::zero()); -/// ``` -pub const M31_CIRCLE_GEN: CirclePoint = CirclePoint { - x: M31::from_u32_unchecked(2), - y: M31::from_u32_unchecked(1268011823), -}; - -/// Order of [M31_CIRCLE_GEN]. -pub const M31_CIRCLE_LOG_ORDER: u32 = 31; - -/// A generator for the circle group over [SecureField]. -pub const SECURE_FIELD_CIRCLE_GEN: CirclePoint = CirclePoint { - x: SecureField::from_u32_unchecked(1, 0, 478637715, 513582971), - y: SecureField::from_u32_unchecked(992285211, 649143431, 740191619, 1186584352), -}; - -/// Order of [SECURE_FIELD_CIRCLE_GEN]. -pub const SECURE_FIELD_CIRCLE_ORDER: u128 = P4 - 1; - -/// Integer i that represent the circle point i * CIRCLE_GEN. Treated as an -/// additive ring modulo `1 << M31_CIRCLE_LOG_ORDER`. -#[derive(Copy, Clone, Debug, PartialEq, Eq, Ord, PartialOrd)] -pub struct CirclePointIndex(pub usize); - -impl CirclePointIndex { - pub fn zero() -> Self { - Self(0) - } - - pub fn generator() -> Self { - Self(1) - } - - pub fn reduce(self) -> Self { - Self(self.0 & ((1 << M31_CIRCLE_LOG_ORDER) - 1)) - } - - pub fn subgroup_gen(log_size: u32) -> Self { - assert!(log_size <= M31_CIRCLE_LOG_ORDER); - Self(1 << (M31_CIRCLE_LOG_ORDER - log_size)) - } - - pub fn to_point(self) -> CirclePoint { - M31_CIRCLE_GEN.mul(self.0 as u128) - } - - pub fn half(self) -> Self { - assert!(self.0 & 1 == 0); - Self(self.0 >> 1) - } - - pub fn try_div(&self, rhs: CirclePointIndex) -> Option { - // Find x s.t. x * rhs.0 = self.0 (mod CIRCLE_ORDER). - let (s, _t, g) = egcd(rhs.0 as isize, 1 << M31_CIRCLE_LOG_ORDER); - if self.0 as isize % g != 0 { - return None; - } - let res = s * self.0 as isize / g; - Some(res) - } -} - -impl Add for CirclePointIndex { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - Self(self.0 + rhs.0).reduce() - } -} - -impl Sub for CirclePointIndex { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - Self(self.0 + (1 << M31_CIRCLE_LOG_ORDER) - rhs.0).reduce() - } -} - -impl Mul for CirclePointIndex { - type Output = Self; - - fn mul(self, rhs: usize) -> Self::Output { - Self(self.0.wrapping_mul(rhs)).reduce() - } -} - -impl Div for CirclePointIndex { - type Output = isize; - - fn div(self, rhs: Self) -> Self::Output { - self.try_div(rhs).unwrap() - } -} - -impl Neg for CirclePointIndex { - type Output = Self; - - fn neg(self) -> Self::Output { - Self((1 << M31_CIRCLE_LOG_ORDER) - self.0).reduce() - } -} - -/// Represents the coset initial + \. -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct Coset { - pub initial_index: CirclePointIndex, - pub initial: CirclePoint, - pub step_size: CirclePointIndex, - pub step: CirclePoint, - pub log_size: u32, -} - -impl Coset { - pub fn new(initial_index: CirclePointIndex, log_size: u32) -> Self { - assert!(log_size <= M31_CIRCLE_LOG_ORDER); - let step_size = CirclePointIndex::subgroup_gen(log_size); - Self { - initial_index, - initial: initial_index.to_point(), - step: step_size.to_point(), - step_size, - log_size, - } - } - - /// Creates a coset of the form . - /// For example, for n=8, we get the point indices \[0,1,2,3,4,5,6,7\]. - pub fn subgroup(log_size: u32) -> Self { - Self::new(CirclePointIndex::zero(), log_size) - } - - /// Creates a coset of the form G_2n + \. - /// For example, for n=8, we get the point indices \[1,3,5,7,9,11,13,15\]. - pub fn odds(log_size: u32) -> Self { - Self::new(CirclePointIndex::subgroup_gen(log_size + 1), log_size) - } - - /// Creates a coset of the form G_4n + . - /// For example, for n=8, we get the point indices \[1,5,9,13,17,21,25,29\]. - /// Its conjugate will be \[3,7,11,15,19,23,27,31\]. - pub fn half_odds(log_size: u32) -> Self { - Self::new(CirclePointIndex::subgroup_gen(log_size + 2), log_size) - } - - /// Returns the size of the coset. - pub fn size(&self) -> usize { - 1 << self.log_size() - } - - /// Returns the log size of the coset. - pub fn log_size(&self) -> u32 { - self.log_size - } - - pub fn iter(&self) -> CosetIterator> { - CosetIterator { - cur: self.initial, - step: self.step, - remaining: self.size(), - } - } - - pub fn iter_indices(&self) -> CosetIterator { - CosetIterator { - cur: self.initial_index, - step: self.step_size, - remaining: self.size(), - } - } - - /// Returns a new coset comprising of all points in current coset doubled. - pub fn double(&self) -> Self { - assert!(self.log_size > 0); - Self { - initial_index: self.initial_index * 2, - initial: self.initial.double(), - step: self.step.double(), - step_size: self.step_size * 2, - log_size: self.log_size.saturating_sub(1), - } - } - - pub fn repeated_double(&self, n_doubles: u32) -> Self { - (0..n_doubles).fold(*self, |coset, _| coset.double()) - } - - pub fn is_doubling_of(&self, other: Self) -> bool { - self.log_size <= other.log_size - && *self == other.repeated_double(other.log_size - self.log_size) - } - - pub fn initial(&self) -> CirclePoint { - self.initial - } - - pub fn index_at(&self, index: usize) -> CirclePointIndex { - self.initial_index + self.step_size.mul(index) - } - - pub fn at(&self, index: usize) -> CirclePoint { - self.index_at(index).to_point() - } - - pub fn shift(&self, shift_size: CirclePointIndex) -> Self { - let initial_index = self.initial_index + shift_size; - Self { - initial_index, - initial: initial_index.to_point(), - ..*self - } - } - - /// Creates the conjugate coset: -initial -\. - pub fn conjugate(&self) -> Self { - let initial_index = -self.initial_index; - let step_size = -self.step_size; - Self { - initial_index, - initial: initial_index.to_point(), - step_size, - step: step_size.to_point(), - log_size: self.log_size, - } - } - - pub fn find(&self, i: CirclePointIndex) -> Option { - let res = (i - self.initial_index).try_div(self.step_size)?; - Some(res.rem_euclid(self.size() as isize) as usize) - } -} - -impl IntoIterator for Coset { - type Item = CirclePoint; - type IntoIter = CosetIterator>; - - /// Iterates over the points in the coset. - fn into_iter(self) -> Self::IntoIter { - self.iter() - } -} - -#[derive(Clone)] -pub struct CosetIterator { - pub cur: T, - pub step: T, - pub remaining: usize, -} - -impl + Copy> Iterator for CosetIterator { - type Item = T; - - fn next(&mut self) -> Option { - if self.remaining == 0 { - return None; - } - self.remaining -= 1; - let res = self.cur; - self.cur = self.cur + self.step; - Some(res) - } -} - -#[cfg(test)] -mod tests { - use std::collections::BTreeSet; - - use num_traits::{One, Pow}; - - use super::{CirclePointIndex, Coset}; - use crate::core::channel::Blake2sChannel; - use crate::core::circle::{CirclePoint, SECURE_FIELD_CIRCLE_GEN}; - use crate::core::fields::qm31::{SecureField, P4}; - use crate::core::fields::FieldExpOps; - use crate::core::poly::circle::CanonicCoset; - - #[test] - fn test_iterator() { - let coset = Coset::new(CirclePointIndex(1), 3); - let actual_indices: Vec<_> = coset.iter_indices().collect(); - let expected_indices = vec![ - CirclePointIndex(1), - CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 1, - CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 2, - CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 3, - CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 4, - CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 5, - CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 6, - CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 7, - ]; - assert_eq!(actual_indices, expected_indices); - - let actual_points = coset.iter().collect::>(); - let expected_points: Vec<_> = expected_indices.iter().map(|i| i.to_point()).collect(); - assert_eq!(actual_points, expected_points); - } - - #[test] - fn test_coset_is_half_coset_with_conjugate() { - let canonic_coset = CanonicCoset::new(8); - let coset_points = BTreeSet::from_iter(canonic_coset.coset().iter()); - - let half_coset_points = BTreeSet::from_iter(canonic_coset.half_coset().iter()); - let half_coset_conjugate_points = - BTreeSet::from_iter(canonic_coset.half_coset().conjugate().iter()); - - assert!((&half_coset_points & &half_coset_conjugate_points).is_empty()); - assert_eq!( - coset_points, - &half_coset_points | &half_coset_conjugate_points - ) - } - - #[test] - pub fn test_get_random_circle_point() { - let mut channel = Blake2sChannel::default(); - - let first_random_circle_point = CirclePoint::get_random_point(&mut channel); - - // Assert that the next random circle point is different. - assert_ne!( - first_random_circle_point, - CirclePoint::get_random_point(&mut channel) - ); - } - - #[test] - pub fn test_secure_field_circle_gen() { - let prime_factors = [ - (2, 33), - (3, 2), - (5, 1), - (7, 1), - (11, 1), - (31, 1), - (151, 1), - (331, 1), - (733, 1), - (1709, 1), - (368140581013, 1), - ]; - - assert_eq!( - prime_factors - .iter() - .map(|(p, e)| p.pow(*e as u32)) - .product::(), - P4 - 1 - ); - assert_eq!( - SECURE_FIELD_CIRCLE_GEN.x.square() + SECURE_FIELD_CIRCLE_GEN.y.square(), - SecureField::one() - ); - assert_eq!(SECURE_FIELD_CIRCLE_GEN.mul(P4 - 1), CirclePoint::zero()); - for (p, _) in prime_factors.iter() { - assert_ne!( - SECURE_FIELD_CIRCLE_GEN.mul((P4 - 1) / *p), - CirclePoint::zero() - ); - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/constraints.rs b/Stwo_wrapper/crates/prover/src/core/constraints.rs deleted file mode 100644 index 776951f..0000000 --- a/Stwo_wrapper/crates/prover/src/core/constraints.rs +++ /dev/null @@ -1,251 +0,0 @@ -use num_traits::One; - -use super::circle::{CirclePoint, Coset}; -use super::fields::m31::BaseField; -use super::fields::qm31::SecureField; -use super::fields::ExtensionOf; -use super::pcs::quotients::PointSample; -use crate::core::fields::ComplexConjugate; - -/// Evaluates a vanishing polynomial of the coset at a point. -pub fn coset_vanishing>(coset: Coset, mut p: CirclePoint) -> F { - // Doubling a point `log_order - 1` times and taking the x coordinate is - // essentially evaluating a polynomial in x of degree `2^(log_order - 1)`. If - // the entire `2^log_order` points of the coset are roots (i.e. yield 0), then - // this is a vanishing polynomial of these points. - - // Rotating the coset -coset.initial + step / 2 yields a canonic coset: - // `step/2 + .` - // Doubling this coset log_order - 1 times yields the coset +-G_4. - // The polynomial x vanishes on these points. - // ```text - // X - // . . - // X - // ``` - p = p - coset.initial.into_ef() + coset.step_size.half().to_point().into_ef(); - //println!("p - coset.initial.into_ef() + coset.step_size.half().to_point().into_ef() = {:?} - {:?} + {:?}",p,coset.initial,coset.step_size.half().0); - let mut x = p.x; - - // The formula for the x coordinate of the double of a point. - for _ in 1..coset.log_size { - x = CirclePoint::double_x(x); - } - x -} - -/// Evaluates the polynomial that is used to exclude the excluded point at point -/// p. Note that this polynomial has a zero of multiplicity 2 at the excluded -/// point. -pub fn point_excluder>( - excluded: CirclePoint, - p: CirclePoint, -) -> F { - (p - excluded.into_ef()).x - BaseField::one() -} - -// A vanishing polynomial on 2 circle points. -pub fn pair_vanishing>( - excluded0: CirclePoint, - excluded1: CirclePoint, - p: CirclePoint, -) -> F { - // The algorithm check computes the area of the triangle formed by the - // 3 points. This is done using the determinant of: - // | p.x p.y 1 | - // | e0.x e0.y 1 | - // | e1.x e1.y 1 | - // This is a polynomial of degree 1 in p.x and p.y, and thus it is a line. - // It vanishes at e0 and e1. - (excluded0.y - excluded1.y) * p.x - + (excluded1.x - excluded0.x) * p.y - + (excluded0.x * excluded1.y - excluded0.y * excluded1.x) -} - -/// Evaluates a vanishing polynomial of the vanish_point at a point. -/// Note that this function has a pole on the antipode of the vanish_point. -pub fn point_vanishing, EF: ExtensionOf>( - vanish_point: CirclePoint, - p: CirclePoint, -) -> EF { - let h = p - vanish_point.into_ef(); - h.y / (EF::one() + h.x) -} - -/// Evaluates a point on a line between a point and its complex conjugate. -/// Relies on the fact that every polynomial F over the base field holds: -/// F(p*) == F(p)* (* being the complex conjugate). -pub fn complex_conjugate_line( - point: CirclePoint, - value: SecureField, - p: CirclePoint, -) -> SecureField { - // TODO(AlonH): This assertion will fail at a probability of 1 to 2^62. Use a better solution. - assert_ne!( - point.y, - point.y.complex_conjugate(), - "Cannot evaluate a line with a single point ({point:?})." - ); - value - + (value.complex_conjugate() - value) * (-point.y + p.y) - / (point.complex_conjugate().y - point.y) -} - -/// Evaluates the coefficients of a line between a point and its complex conjugate. Specifically, -/// `a, b, and c, s.t. a*x + b -c*y = 0` for (x,y) being (sample.y, sample.value) and -/// (conj(sample.y), conj(sample.value)). -/// Relies on the fact that every polynomial F over the base -/// field holds: F(p*) == F(p)* (* being the complex conjugate). -pub fn complex_conjugate_line_coeffs( - sample: &PointSample, - alpha: SecureField, -) -> (SecureField, SecureField, SecureField) { - // TODO(AlonH): This assertion will fail at a probability of 1 to 2^62. Use a better solution. - assert_ne!( - sample.point.y, - sample.point.y.complex_conjugate(), - "Cannot evaluate a line with a single point ({:?}).", - sample.point - ); - let a = sample.value.complex_conjugate() - sample.value; - let c = sample.point.complex_conjugate().y - sample.point.y; - let b = sample.value * c - a * sample.point.y; - (alpha * a, alpha * b, alpha * c) -} - -#[cfg(test)] -mod tests { - use num_traits::Zero; - - use super::{coset_vanishing, point_excluder, point_vanishing}; - use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly}; - use crate::core::circle::{CirclePoint, CirclePointIndex, Coset}; - use crate::core::constraints::{complex_conjugate_line, pair_vanishing}; - use crate::core::fields::m31::{BaseField, M31}; - use crate::core::fields::qm31::SecureField; - use crate::core::fields::{ComplexConjugate, FieldExpOps}; - use crate::core::poly::circle::CanonicCoset; - use crate::core::poly::NaturalOrder; - use crate::core::test_utils::secure_eval_to_base_eval; - use crate::m31; - - #[test] - fn test_coset_vanishing() { - let cosets = [ - Coset::half_odds(5), - Coset::odds(5), - Coset::new(CirclePointIndex::zero(), 5), - Coset::half_odds(5).conjugate(), - ]; - for c0 in cosets.iter() { - for el in c0.iter() { - assert_eq!(coset_vanishing(*c0, el), BaseField::zero()); - for c1 in cosets.iter() { - if c0 == c1 { - continue; - } - assert_ne!(coset_vanishing(*c1, el), BaseField::zero()); - } - } - } - } - - #[test] - fn test_point_excluder() { - let excluded = Coset::half_odds(5).at(10); - let point = (CirclePointIndex::generator() * 4).to_point(); - - let num = point_excluder(excluded, point) * point_excluder(excluded.conjugate(), point); - let denom = (point.x - excluded.x).pow(2); - - assert_eq!(num, denom); - } - - #[test] - fn test_pair_excluder() { - let excluded0 = Coset::half_odds(5).at(10); - let excluded1 = Coset::half_odds(5).at(13); - let point = (CirclePointIndex::generator() * 4).to_point(); - - assert_ne!(pair_vanishing(excluded0, excluded1, point), M31::zero()); - assert_eq!(pair_vanishing(excluded0, excluded1, excluded0), M31::zero()); - assert_eq!(pair_vanishing(excluded0, excluded1, excluded1), M31::zero()); - } - - #[test] - fn test_point_vanishing_success() { - let coset = Coset::odds(5); - let vanish_point = coset.at(2); - for el in coset.iter() { - if el == vanish_point { - assert_eq!(point_vanishing(vanish_point, el), BaseField::zero()); - continue; - } - if el == vanish_point.antipode() { - continue; - } - assert_ne!(point_vanishing(vanish_point, el), BaseField::zero()); - } - } - - #[test] - #[should_panic(expected = "0 has no inverse")] - fn test_point_vanishing_failure() { - let coset = Coset::half_odds(6); - let point = coset.at(4); - point_vanishing(point, point.antipode()); - } - - #[test] - fn test_complex_conjugate_symmetry() { - // Create a polynomial over a base circle domain. - let polynomial = CpuCirclePoly::new((0..1 << 7).map(|i| m31!(i)).collect()); - let oods_point = CirclePoint::get_point(9834759221); - - // Assert that the base field polynomial is complex conjugate symmetric. - assert_eq!( - polynomial.eval_at_point(oods_point.complex_conjugate()), - polynomial.eval_at_point(oods_point).complex_conjugate() - ); - } - - #[test] - fn test_point_vanishing_degree() { - // Create a polynomial over a circle domain. - let log_domain_size = 7; - let domain_size = 1 << log_domain_size; - let polynomial = CpuCirclePoly::new((0..domain_size).map(|i| m31!(i)).collect()); - - // Create a larger domain. - let log_large_domain_size = log_domain_size + 1; - let large_domain_size = 1 << log_large_domain_size; - let large_domain = CanonicCoset::new(log_large_domain_size).circle_domain(); - - // Create a vanish point that is not in the large domain. - let vanish_point = CirclePoint::get_point(97); - let vanish_point_value = polynomial.eval_at_point(vanish_point); - - // Compute the quotient polynomial. - let mut quotient_polynomial_values = Vec::with_capacity(large_domain_size as usize); - for point in large_domain.iter() { - let line = complex_conjugate_line(vanish_point, vanish_point_value, point); - let mut value = polynomial.eval_at_point(point.into_ef()) - line; - value /= pair_vanishing( - vanish_point, - vanish_point.complex_conjugate(), - point.into_ef(), - ); - quotient_polynomial_values.push(value); - } - let quotient_evaluation = CpuCircleEvaluation::::new( - large_domain, - quotient_polynomial_values, - ); - let quotient_polynomial = secure_eval_to_base_eval("ient_evaluation) - .bit_reverse() - .interpolate(); - - // Check that the quotient polynomial is indeed in the wanted fft space. - assert!(quotient_polynomial.is_in_fft_space(log_domain_size)); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/fft.rs b/Stwo_wrapper/crates/prover/src/core/fft.rs deleted file mode 100644 index 630fbe7..0000000 --- a/Stwo_wrapper/crates/prover/src/core/fft.rs +++ /dev/null @@ -1,21 +0,0 @@ -use std::ops::{Add, AddAssign, Mul, Sub}; - -use super::fields::m31::BaseField; - -pub fn butterfly(v0: &mut F, v1: &mut F, twid: BaseField) -where - F: Copy + AddAssign + Sub + Mul, -{ - let tmp = *v1 * twid; - *v1 = *v0 - tmp; - *v0 += tmp; -} - -pub fn ibutterfly(v0: &mut F, v1: &mut F, itwid: BaseField) -where - F: Copy + AddAssign + Add + Sub + Mul, -{ - let tmp = *v0; - *v0 = tmp + *v1; - *v1 = (tmp - *v1) * itwid; -} diff --git a/Stwo_wrapper/crates/prover/src/core/fields/cm31.rs b/Stwo_wrapper/crates/prover/src/core/fields/cm31.rs deleted file mode 100644 index 6f1b6c2..0000000 --- a/Stwo_wrapper/crates/prover/src/core/fields/cm31.rs +++ /dev/null @@ -1,137 +0,0 @@ -use std::fmt::{Debug, Display}; -use std::ops::{ - Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign, -}; - -use serde::{Deserialize, Serialize}; - -use super::{ComplexConjugate, FieldExpOps}; -use crate::core::fields::m31::M31; -use crate::{impl_extension_field, impl_field}; -pub const P2: u64 = 4611686014132420609; // (2 ** 31 - 1) ** 2 - -/// Complex extension field of M31. -/// Equivalent to M31\[x\] over (x^2 + 1) as the irreducible polynomial. -/// Represented as (a, b) of a + bi. -#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] -pub struct CM31(pub M31, pub M31); - -impl_field!(CM31, P2); -impl_extension_field!(CM31, M31); - -impl CM31 { - pub const fn from_u32_unchecked(a: u32, b: u32) -> CM31 { - Self(M31::from_u32_unchecked(a), M31::from_u32_unchecked(b)) - } - - pub fn from_m31(a: M31, b: M31) -> CM31 { - Self(a, b) - } -} - -impl Display for CM31 { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{} + {}i", self.0, self.1) - } -} - -impl Debug for CM31 { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{} + {}i", self.0, self.1) - } -} - -impl Mul for CM31 { - type Output = Self; - - fn mul(self, rhs: Self) -> Self::Output { - // (a + bi) * (c + di) = (ac - bd) + (ad + bc)i. - Self( - self.0 * rhs.0 - self.1 * rhs.1, - self.0 * rhs.1 + self.1 * rhs.0, - ) - } -} - -impl TryInto for CM31 { - type Error = (); - - fn try_into(self) -> Result { - if self.1 != M31::zero() { - return Err(()); - } - Ok(self.0) - } -} - -impl FieldExpOps for CM31 { - fn inverse(&self) -> Self { - assert!(!self.is_zero(), "0 has no inverse"); - // 1 / (a + bi) = (a - bi) / (a^2 + b^2). - Self(self.0, -self.1) * (self.0.square() + self.1.square()).inverse() - } -} - -#[cfg(test)] -#[macro_export] -macro_rules! cm31 { - ($m0:expr, $m1:expr) => { - CM31::from_u32_unchecked($m0, $m1) - }; -} - -#[cfg(test)] -mod tests { - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use super::CM31; - use crate::core::fields::m31::P; - use crate::core::fields::{FieldExpOps, IntoSlice}; - use crate::m31; - - #[test] - fn test_inverse() { - let cm = cm31!(1, 2); - let cm_inv = cm.inverse(); - assert_eq!(cm * cm_inv, cm31!(1, 0)); - } - - #[test] - fn test_ops() { - let cm0 = cm31!(1, 2); - let cm1 = cm31!(4, 5); - let m = m31!(8); - let cm = CM31::from(m); - let cm0_x_cm1 = cm31!(P - 6, 13); - - assert_eq!(cm0 + cm1, cm31!(5, 7)); - assert_eq!(cm1 + m, cm1 + cm); - assert_eq!(cm0 * cm1, cm0_x_cm1); - assert_eq!(cm1 * m, cm1 * cm); - assert_eq!(-cm0, cm31!(P - 1, P - 2)); - assert_eq!(cm0 - cm1, cm31!(P - 3, P - 3)); - assert_eq!(cm1 - m, cm1 - cm); - assert_eq!(cm0_x_cm1 / cm1, cm31!(1, 2)); - assert_eq!(cm1 / m, cm1 / cm); - } - - #[test] - fn test_into_slice() { - let mut rng = SmallRng::seed_from_u64(0); - let x = (0..100).map(|_| rng.gen()).collect::>(); - - let slice = CM31::into_slice(&x); - - for i in 0..100 { - let corresponding_sub_slice = &slice[i * 8..(i + 1) * 8]; - assert_eq!( - x[i], - cm31!( - u32::from_le_bytes(corresponding_sub_slice[..4].try_into().unwrap()), - u32::from_le_bytes(corresponding_sub_slice[4..].try_into().unwrap()) - ) - ) - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/fields/m31.rs b/Stwo_wrapper/crates/prover/src/core/fields/m31.rs deleted file mode 100644 index 852f959..0000000 --- a/Stwo_wrapper/crates/prover/src/core/fields/m31.rs +++ /dev/null @@ -1,258 +0,0 @@ -use std::fmt::Display; -use std::ops::{ - Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign, -}; - -use bytemuck::{Pod, Zeroable}; -use rand::distributions::{Distribution, Standard}; -use serde::{Deserialize, Serialize}; - -use super::{ComplexConjugate, FieldExpOps}; -use crate::impl_field; -pub const MODULUS_BITS: u32 = 31; -pub const N_BYTES_FELT: usize = 4; -pub const P: u32 = 2147483647; // 2 ** 31 - 1 - -#[repr(transparent)] -#[derive( - Copy, - Clone, - Debug, - Default, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - Pod, - Zeroable, - Serialize, - Deserialize, -)] -pub struct M31(pub u32); -pub type BaseField = M31; - -impl_field!(M31, P); - -impl M31 { - /// Returns `val % P` when `val` is in the range `[0, 2P)`. - /// - /// ``` - /// use stwo_prover::core::fields::m31::{M31, P}; - /// - /// let val = 2 * P - 19; - /// assert_eq!(M31::partial_reduce(val), M31::from(P - 19)); - /// ``` - pub fn partial_reduce(val: u32) -> Self { - Self(val.checked_sub(P).unwrap_or(val)) - } - - /// Returns `val % P` when `val` is in the range `[0, P^2)`. - /// - /// ``` - /// use stwo_prover::core::fields::m31::{M31, P}; - /// - /// let val = (P as u64).pow(2) - 19; - /// assert_eq!(M31::reduce(val), M31::from(P - 19)); - /// ``` - pub fn reduce(val: u64) -> Self { - Self((((((val >> MODULUS_BITS) + val + 1) >> MODULUS_BITS) + val) & (P as u64)) as u32) - } - - pub const fn from_u32_unchecked(arg: u32) -> Self { - Self(arg) - } -} - -impl Display for M31 { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl Add for M31 { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - Self::partial_reduce(self.0 + rhs.0) - } -} - -impl Neg for M31 { - type Output = Self; - - fn neg(self) -> Self::Output { - Self::partial_reduce(P - self.0) - } -} - -impl Sub for M31 { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - Self::partial_reduce(self.0 + P - rhs.0) - } -} - -impl Mul for M31 { - type Output = Self; - - fn mul(self, rhs: Self) -> Self::Output { - Self::reduce((self.0 as u64) * (rhs.0 as u64)) - } -} - -impl FieldExpOps for M31 { - /// ``` - /// use num_traits::One; - /// use stwo_prover::core::fields::m31::BaseField; - /// use stwo_prover::core::fields::FieldExpOps; - /// - /// let v = BaseField::from(19); - /// assert_eq!(v.inverse() * v, BaseField::one()); - /// ``` - fn inverse(&self) -> Self { - assert!(!self.is_zero(), "0 has no inverse"); - pow2147483645(*self) - } -} - -impl ComplexConjugate for M31 { - fn complex_conjugate(&self) -> Self { - *self - } -} - -impl One for M31 { - fn one() -> Self { - Self(1) - } -} - -impl Zero for M31 { - fn zero() -> Self { - Self(0) - } - - fn is_zero(&self) -> bool { - *self == Self::zero() - } -} - -impl From for M31 { - fn from(value: usize) -> Self { - M31::reduce(value.try_into().unwrap()) - } -} - -impl From for M31 { - fn from(value: u32) -> Self { - M31::reduce(value.into()) - } -} - -impl From for M31 { - fn from(value: i32) -> Self { - M31::reduce(value.try_into().unwrap()) - } -} - -impl Distribution for Standard { - // Not intended for cryptographic use. Should only be used in tests and benchmarks. - fn sample(&self, rng: &mut R) -> M31 { - M31(rng.gen_range(0..P)) - } -} - -#[cfg(test)] -#[macro_export] -macro_rules! m31 { - ($m:expr) => { - $crate::core::fields::m31::M31::from_u32_unchecked($m) - }; -} - -/// Computes `v^((2^31-1)-2)`. -/// -/// Computes the multiplicative inverse of [`M31`] elements with 37 multiplications vs naive 60 -/// multiplications. Made generic to support both vectorized and non-vectorized implementations. -/// Multiplication tree found with [addchain](https://github.com/mmcloughlin/addchain). -/// -/// ``` -/// use stwo_prover::core::fields::m31::{pow2147483645, BaseField}; -/// use stwo_prover::core::fields::FieldExpOps; -/// -/// let v = BaseField::from(19); -/// assert_eq!(pow2147483645(v), v.pow(2147483645)); -/// ``` -pub fn pow2147483645(v: T) -> T { - let t0 = sqn::<2, T>(v) * v; - let t1 = sqn::<1, T>(t0) * t0; - let t2 = sqn::<3, T>(t1) * t0; - let t3 = sqn::<1, T>(t2) * t0; - let t4 = sqn::<8, T>(t3) * t3; - let t5 = sqn::<8, T>(t4) * t3; - sqn::<7, T>(t5) * t2 -} - -/// Computes `v^(2*n)`. -fn sqn(mut v: T) -> T { - for _ in 0..N { - v = v.square(); - } - v -} - -#[cfg(test)] -mod tests { - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use super::{M31, P}; - use crate::core::fields::IntoSlice; - - fn mul_p(a: u32, b: u32) -> u32 { - ((a as u64 * b as u64) % P as u64) as u32 - } - - fn add_p(a: u32, b: u32) -> u32 { - (a + b) % P - } - - fn neg_p(a: u32) -> u32 { - if a == 0 { - 0 - } else { - P - a - } - } - - #[test] - fn test_basic_ops() { - let mut rng = SmallRng::seed_from_u64(0); - for _ in 0..10000 { - let x: u32 = rng.gen::() % P; - let y: u32 = rng.gen::() % P; - assert_eq!(m31!(add_p(x, y)), m31!(x) + m31!(y)); - assert_eq!(m31!(mul_p(x, y)), m31!(x) * m31!(y)); - assert_eq!(m31!(neg_p(x)), -m31!(x)); - } - } - - #[test] - fn test_into_slice() { - let mut rng = SmallRng::seed_from_u64(0); - let x = (0..100).map(|_| rng.gen()).collect::>(); - - let slice = M31::into_slice(&x); - - for i in 0..100 { - assert_eq!( - x[i], - m31!(u32::from_le_bytes( - slice[i * 4..(i + 1) * 4].try_into().unwrap() - )) - ); - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/fields/mod.rs b/Stwo_wrapper/crates/prover/src/core/fields/mod.rs deleted file mode 100644 index fbeefbb..0000000 --- a/Stwo_wrapper/crates/prover/src/core/fields/mod.rs +++ /dev/null @@ -1,489 +0,0 @@ -use std::fmt::{Debug, Display}; -use std::iter::{Product, Sum}; -use std::ops::{Mul, MulAssign, Neg}; - -use num_traits::{NumAssign, NumAssignOps, NumOps, One}; - -use super::backend::ColumnOps; - -pub mod cm31; -pub mod m31; -pub mod qm31; -pub mod secure_column; - -pub trait FieldOps: ColumnOps { - // TODO(Ohad): change to use a mutable slice. - fn batch_inverse(column: &Self::Column, dst: &mut Self::Column); -} - -pub trait FieldExpOps: Mul + MulAssign + Sized + One + Copy { - fn square(&self) -> Self { - (*self) * (*self) - } - - fn pow(&self, exp: u128) -> Self { - let mut res = Self::one(); - let mut base = *self; - let mut exp = exp; - while exp > 0 { - if exp & 1 == 1 { - res *= base; - } - base = base.square(); - exp >>= 1; - } - res - } - - fn inverse(&self) -> Self; - - /// Inverts a batch of elements using Montgomery's trick. - fn batch_inverse(column: &[Self], dst: &mut [Self]) { - const WIDTH: usize = 4; - let n = column.len(); - debug_assert!(dst.len() >= n); - - if n <= WIDTH || n % WIDTH != 0 { - batch_inverse_classic(column, dst); - return; - } - - // First pass. Compute 'WIDTH' cumulative products in an interleaving fashion, reducing - // instruction dependency and allowing better pipelining. - let mut cum_prod: [Self; WIDTH] = [Self::one(); WIDTH]; - dst[..WIDTH].copy_from_slice(&cum_prod); - for i in 0..n { - cum_prod[i % WIDTH] *= column[i]; - dst[i] = cum_prod[i % WIDTH]; - } - - // Inverse cumulative products. - // Use classic batch inversion. - let mut tail_inverses = [Self::one(); WIDTH]; - batch_inverse_classic(&dst[n - WIDTH..], &mut tail_inverses); - - // Second pass. - for i in (WIDTH..n).rev() { - dst[i] = dst[i - WIDTH] * tail_inverses[i % WIDTH]; - tail_inverses[i % WIDTH] *= column[i]; - } - dst[0..WIDTH].copy_from_slice(&tail_inverses); - } -} - -/// Assumes dst is initialized and of the same length as column. -fn batch_inverse_classic(column: &[T], dst: &mut [T]) { - let n = column.len(); - debug_assert!(dst.len() >= n); - - dst[0] = column[0]; - // First pass. - for i in 1..n { - dst[i] = dst[i - 1] * column[i]; - } - - // Inverse cumulative product. - let mut curr_inverse = dst[n - 1].inverse(); - - // Second pass. - for i in (1..n).rev() { - dst[i] = dst[i - 1] * curr_inverse; - curr_inverse *= column[i]; - } - dst[0] = curr_inverse; -} - -pub trait Field: - NumAssign - + Neg - + ComplexConjugate - + Copy - + Default - + Debug - + Display - + PartialOrd - + Ord - + Send - + Sync - + Sized - + FieldExpOps - + Product - + for<'a> Product<&'a Self> - + Sum - + for<'a> Sum<&'a Self> -{ - fn double(&self) -> Self { - (*self) + (*self) - } -} - -/// # Safety -/// -/// Do not use unless you are aware of the endianess in the platform you are compiling for, and the -/// Field element's representation in memory. -// TODO(Ohad): Do not compile on non-le targets. -pub unsafe trait IntoSlice: Sized { - fn into_slice(sl: &[Self]) -> &[T] { - unsafe { - std::slice::from_raw_parts( - sl.as_ptr() as *const T, - std::mem::size_of_val(sl) / std::mem::size_of::(), - ) - } - } -} - -unsafe impl IntoSlice for F {} - -pub trait ComplexConjugate { - /// # Example - /// - /// ``` - /// use stwo_prover::core::fields::m31::P; - /// use stwo_prover::core::fields::qm31::QM31; - /// use stwo_prover::core::fields::ComplexConjugate; - /// - /// let x = QM31::from_u32_unchecked(1, 2, 3, 4); - /// assert_eq!( - /// x.complex_conjugate(), - /// QM31::from_u32_unchecked(1, 2, P - 3, P - 4) - /// ); - /// ``` - fn complex_conjugate(&self) -> Self; -} - -pub trait ExtensionOf: Field + From + NumOps + NumAssignOps { - const EXTENSION_DEGREE: usize; -} - -impl ExtensionOf for F { - const EXTENSION_DEGREE: usize = 1; -} - -#[macro_export] -macro_rules! impl_field { - ($field_name: ty, $field_size: ident) => { - use std::iter::{Product, Sum}; - - use num_traits::{Num, One, Zero}; - use $crate::core::fields::Field; - - impl Num for $field_name { - type FromStrRadixErr = Box; - - fn from_str_radix(_str: &str, _radix: u32) -> Result { - unimplemented!( - "Num::from_str_radix is not implemented for {}", - stringify!($field_name) - ); - } - } - - impl Field for $field_name {} - - impl AddAssign for $field_name { - fn add_assign(&mut self, rhs: Self) { - *self = *self + rhs; - } - } - - impl SubAssign for $field_name { - fn sub_assign(&mut self, rhs: Self) { - *self = *self - rhs; - } - } - - impl MulAssign for $field_name { - fn mul_assign(&mut self, rhs: Self) { - *self = *self * rhs; - } - } - - impl Div for $field_name { - type Output = Self; - - #[allow(clippy::suspicious_arithmetic_impl)] - fn div(self, rhs: Self) -> Self::Output { - self * rhs.inverse() - } - } - - impl DivAssign for $field_name { - fn div_assign(&mut self, rhs: Self) { - *self = *self / rhs; - } - } - - impl Rem for $field_name { - type Output = Self; - - fn rem(self, _rhs: Self) -> Self::Output { - unimplemented!("Rem is not implemented for {}", stringify!($field_name)); - } - } - - impl RemAssign for $field_name { - fn rem_assign(&mut self, _rhs: Self) { - unimplemented!( - "RemAssign is not implemented for {}", - stringify!($field_name) - ); - } - } - - impl Product for $field_name { - fn product(mut iter: I) -> Self - where - I: Iterator, - { - let first = iter.next().unwrap_or_else(Self::one); - iter.fold(first, |a, b| a * b) - } - } - - impl<'a> Product<&'a Self> for $field_name { - fn product(iter: I) -> Self - where - I: Iterator, - { - iter.map(|&v| v).product() - } - } - - impl Sum for $field_name { - fn sum(mut iter: I) -> Self - where - I: Iterator, - { - let first = iter.next().unwrap_or_else(Self::zero); - iter.fold(first, |a, b| a + b) - } - } - - impl<'a> Sum<&'a Self> for $field_name { - fn sum(iter: I) -> Self - where - I: Iterator, - { - iter.map(|&v| v).sum() - } - } - }; -} - -/// Used to extend a field (with characteristic M31) by 2. -#[macro_export] -macro_rules! impl_extension_field { - ($field_name: ident, $extended_field_name: ty) => { - use rand::distributions::{Distribution, Standard}; - use $crate::core::fields::ExtensionOf; - - impl ExtensionOf for $field_name { - const EXTENSION_DEGREE: usize = - <$extended_field_name as ExtensionOf>::EXTENSION_DEGREE * 2; - } - - impl Add for $field_name { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - Self(self.0 + rhs.0, self.1 + rhs.1) - } - } - - impl Neg for $field_name { - type Output = Self; - - fn neg(self) -> Self::Output { - Self(-self.0, -self.1) - } - } - - impl Sub for $field_name { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - Self(self.0 - rhs.0, self.1 - rhs.1) - } - } - - impl One for $field_name { - fn one() -> Self { - Self( - <$extended_field_name>::one(), - <$extended_field_name>::zero(), - ) - } - } - - impl Zero for $field_name { - fn zero() -> Self { - Self( - <$extended_field_name>::zero(), - <$extended_field_name>::zero(), - ) - } - - fn is_zero(&self) -> bool { - *self == Self::zero() - } - } - - impl Add for $field_name { - type Output = Self; - - fn add(self, rhs: M31) -> Self::Output { - Self(self.0 + rhs, self.1) - } - } - - impl Add<$field_name> for M31 { - type Output = $field_name; - - fn add(self, rhs: $field_name) -> Self::Output { - rhs + self - } - } - - impl Sub for $field_name { - type Output = Self; - - fn sub(self, rhs: M31) -> Self::Output { - Self(self.0 - rhs, self.1) - } - } - - impl Sub<$field_name> for M31 { - type Output = $field_name; - - fn sub(self, rhs: $field_name) -> Self::Output { - -rhs + self - } - } - - impl Mul for $field_name { - type Output = Self; - - fn mul(self, rhs: M31) -> Self::Output { - Self(self.0 * rhs, self.1 * rhs) - } - } - - impl Mul<$field_name> for M31 { - type Output = $field_name; - - fn mul(self, rhs: $field_name) -> Self::Output { - rhs * self - } - } - - impl Div for $field_name { - type Output = Self; - - fn div(self, rhs: M31) -> Self::Output { - Self(self.0 / rhs, self.1 / rhs) - } - } - - impl Div<$field_name> for M31 { - type Output = $field_name; - - #[allow(clippy::suspicious_arithmetic_impl)] - fn div(self, rhs: $field_name) -> Self::Output { - rhs.inverse() * self - } - } - - impl ComplexConjugate for $field_name { - fn complex_conjugate(&self) -> Self { - Self(self.0, -self.1) - } - } - - impl From for $field_name { - fn from(x: M31) -> Self { - Self(x.into(), <$extended_field_name>::zero()) - } - } - - impl AddAssign for $field_name { - fn add_assign(&mut self, rhs: M31) { - *self = *self + rhs; - } - } - - impl SubAssign for $field_name { - fn sub_assign(&mut self, rhs: M31) { - *self = *self - rhs; - } - } - - impl MulAssign for $field_name { - fn mul_assign(&mut self, rhs: M31) { - *self = *self * rhs; - } - } - - impl DivAssign for $field_name { - fn div_assign(&mut self, rhs: M31) { - *self = *self / rhs; - } - } - - impl Rem for $field_name { - type Output = Self; - - fn rem(self, _rhs: M31) -> Self::Output { - unimplemented!("Rem is not implemented for {}", stringify!($field_name)); - } - } - - impl RemAssign for $field_name { - fn rem_assign(&mut self, _rhs: M31) { - unimplemented!( - "RemAssign is not implemented for {}", - stringify!($field_name) - ); - } - } - - impl Distribution<$field_name> for Standard { - // Not intended for cryptographic use. Should only be used in tests and benchmarks. - fn sample(&self, rng: &mut R) -> $field_name { - $field_name(rng.gen(), rng.gen()) - } - } - }; -} - -#[cfg(test)] -mod tests { - use num_traits::Zero; - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use crate::core::fields::m31::M31; - use crate::core::fields::FieldExpOps; - - #[test] - fn test_slice_batch_inverse() { - let mut rng = SmallRng::seed_from_u64(0); - let elements: [M31; 16] = rng.gen(); - let expected = elements.iter().map(|e| e.inverse()).collect::>(); - let mut dst = [M31::zero(); 16]; - - M31::batch_inverse(&elements, &mut dst); - - assert_eq!(expected, dst); - } - - #[test] - #[should_panic] - fn test_slice_batch_inverse_wrong_dst_size() { - let mut rng = SmallRng::seed_from_u64(0); - let elements: [M31; 16] = rng.gen(); - let mut dst = [M31::zero(); 15]; - - M31::batch_inverse(&elements, &mut dst); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/fields/qm31.rs b/Stwo_wrapper/crates/prover/src/core/fields/qm31.rs deleted file mode 100644 index 6da19a3..0000000 --- a/Stwo_wrapper/crates/prover/src/core/fields/qm31.rs +++ /dev/null @@ -1,195 +0,0 @@ -use std::fmt::{Debug, Display}; -use std::ops::{ - Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign, -}; - -use serde::{Deserialize, Serialize}; - -use super::secure_column::SECURE_EXTENSION_DEGREE; -use super::{ComplexConjugate, FieldExpOps}; -use crate::core::fields::cm31::CM31; -use crate::core::fields::m31::M31; -use crate::{impl_extension_field, impl_field}; - -pub const P4: u128 = 21267647892944572736998860269687930881; // (2 ** 31 - 1) ** 4 -pub const R: CM31 = CM31::from_u32_unchecked(2, 1); - -/// Extension field of CM31. -/// Equivalent to CM31\[x\] over (x^2 - 2 - i) as the irreducible polynomial. -/// Represented as ((a, b), (c, d)) of (a + bi) + (c + di)u. -#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] -pub struct QM31(pub CM31, pub CM31); -pub type SecureField = QM31; - -impl_field!(QM31, P4); -impl_extension_field!(QM31, CM31); - -impl QM31 { - pub const fn from_u32_unchecked(a: u32, b: u32, c: u32, d: u32) -> Self { - Self( - CM31::from_u32_unchecked(a, b), - CM31::from_u32_unchecked(c, d), - ) - } - - pub fn from_m31(a: M31, b: M31, c: M31, d: M31) -> Self { - Self(CM31::from_m31(a, b), CM31::from_m31(c, d)) - } - - pub fn from_m31_array(array: [M31; SECURE_EXTENSION_DEGREE]) -> Self { - Self::from_m31(array[0], array[1], array[2], array[3]) - } - - pub fn to_m31_array(self) -> [M31; SECURE_EXTENSION_DEGREE] { - [self.0 .0, self.0 .1, self.1 .0, self.1 .1] - } - - /// Returns the combined value, given the values of its composing base field polynomials at that - /// point. - pub fn from_partial_evals(evals: [Self; SECURE_EXTENSION_DEGREE]) -> Self { - let mut res = evals[0]; - res += evals[1] * Self::from_u32_unchecked(0, 1, 0, 0); - res += evals[2] * Self::from_u32_unchecked(0, 0, 1, 0); - res += evals[3] * Self::from_u32_unchecked(0, 0, 0, 1); - res - } - - // Note: Adding this as a Mul impl drives rust insane, and it tries to infer Qm31*Qm31 as - // QM31*CM31. - pub fn mul_cm31(self, rhs: CM31) -> Self { - Self(self.0 * rhs, self.1 * rhs) - } -} - -impl Display for QM31 { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "({}) + ({})u", self.0, self.1) - } -} - -impl Debug for QM31 { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "({}) + ({})u", self.0, self.1) - } -} - -impl Mul for QM31 { - type Output = Self; - - fn mul(self, rhs: Self) -> Self::Output { - // (a + bu) * (c + du) = (ac + rbd) + (ad + bc)u. - Self( - self.0 * rhs.0 + R * self.1 * rhs.1, - self.0 * rhs.1 + self.1 * rhs.0, - ) - } -} - -impl From for QM31 { - fn from(value: usize) -> Self { - M31::from(value).into() - } -} - -impl From for QM31 { - fn from(value: u32) -> Self { - M31::from(value).into() - } -} - -impl From for QM31 { - fn from(value: i32) -> Self { - M31::from(value).into() - } -} - -impl TryInto for QM31 { - type Error = (); - - fn try_into(self) -> Result { - if self.1 != CM31::zero() { - return Err(()); - } - self.0.try_into() - } -} - -impl FieldExpOps for QM31 { - fn inverse(&self) -> Self { - assert!(!self.is_zero(), "0 has no inverse"); - // (a + bu)^-1 = (a - bu) / (a^2 - (2+i)b^2). - let b2 = self.1.square(); - let ib2 = CM31(-b2.1, b2.0); - let denom = self.0.square() - (b2 + b2 + ib2); - let denom_inverse = denom.inverse(); - Self(self.0 * denom_inverse, -self.1 * denom_inverse) - } -} - -#[cfg(test)] -#[macro_export] -macro_rules! qm31 { - ($m0:expr, $m1:expr, $m2:expr, $m3:expr) => {{ - use $crate::core::fields::qm31::QM31; - QM31::from_u32_unchecked($m0, $m1, $m2, $m3) - }}; -} - -#[cfg(test)] -mod tests { - use num_traits::One; - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use super::QM31; - use crate::core::fields::m31::P; - use crate::core::fields::{FieldExpOps, IntoSlice}; - use crate::m31; - - #[test] - fn test_inverse() { - let qm = qm31!(1, 2, 3, 4); - let qm_inv = qm.inverse(); - assert_eq!(qm * qm_inv, QM31::one()); - } - - #[test] - fn test_ops() { - let qm0 = qm31!(1, 2, 3, 4); - let qm1 = qm31!(4, 5, 6, 7); - let m = m31!(8); - let qm = QM31::from(m); - let qm0_x_qm1 = qm31!(P - 71, 93, P - 16, 50); - - assert_eq!(qm0 + qm1, qm31!(5, 7, 9, 11)); - assert_eq!(qm1 + m, qm1 + qm); - assert_eq!(qm0 * qm1, qm0_x_qm1); - assert_eq!(qm1 * m, qm1 * qm); - assert_eq!(-qm0, qm31!(P - 1, P - 2, P - 3, P - 4)); - assert_eq!(qm0 - qm1, qm31!(P - 3, P - 3, P - 3, P - 3)); - assert_eq!(qm1 - m, qm1 - qm); - assert_eq!(qm0_x_qm1 / qm1, qm31!(1, 2, 3, 4)); - assert_eq!(qm1 / m, qm1 / qm); - } - - #[test] - fn test_into_slice() { - let mut rng = SmallRng::seed_from_u64(0); - let x = (0..100).map(|_| rng.gen()).collect::>(); - - let slice = QM31::into_slice(&x); - - for i in 0..100 { - let corresponding_sub_slice = &slice[i * 16..(i + 1) * 16]; - assert_eq!( - x[i], - qm31!( - u32::from_le_bytes(corresponding_sub_slice[..4].try_into().unwrap()), - u32::from_le_bytes(corresponding_sub_slice[4..8].try_into().unwrap()), - u32::from_le_bytes(corresponding_sub_slice[8..12].try_into().unwrap()), - u32::from_le_bytes(corresponding_sub_slice[12..16].try_into().unwrap()) - ) - ) - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/fields/secure_column.rs b/Stwo_wrapper/crates/prover/src/core/fields/secure_column.rs deleted file mode 100644 index 073d21b..0000000 --- a/Stwo_wrapper/crates/prover/src/core/fields/secure_column.rs +++ /dev/null @@ -1,111 +0,0 @@ -use std::array; -use std::iter::zip; - -use super::m31::BaseField; -use super::qm31::SecureField; -use super::{ExtensionOf, FieldOps}; -use crate::core::backend::{Col, Column, CpuBackend}; - -pub const SECURE_EXTENSION_DEGREE: usize = - >::EXTENSION_DEGREE; - -/// A column major array of `SECURE_EXTENSION_DEGREE` base field columns, that represents a column -/// of secure field element coordinates. -#[derive(Clone, Debug)] -pub struct SecureColumnByCoords> { - pub columns: [Col; SECURE_EXTENSION_DEGREE], -} -impl SecureColumnByCoords { - // TODO(spapini): Remove when we no longer use CircleEvaluation. - pub fn to_vec(&self) -> Vec { - (0..self.len()).map(|i| self.at(i)).collect() - } -} -impl> SecureColumnByCoords { - pub fn at(&self, index: usize) -> SecureField { - SecureField::from_m31_array(std::array::from_fn(|i| self.columns[i].at(index))) - } - - pub fn zeros(len: usize) -> Self { - Self { - columns: std::array::from_fn(|_| Col::::zeros(len)), - } - } - - /// # Safety - pub unsafe fn uninitialized(len: usize) -> Self { - Self { - columns: std::array::from_fn(|_| Col::::uninitialized(len)), - } - } - - pub fn len(&self) -> usize { - self.columns[0].len() - } - - pub fn is_empty(&self) -> bool { - self.columns[0].is_empty() - } - - pub fn to_cpu(&self) -> SecureColumnByCoords { - SecureColumnByCoords { - columns: self.columns.clone().map(|c| c.to_cpu()), - } - } - - pub fn set(&mut self, index: usize, value: SecureField) { - let values = value.to_m31_array(); - #[allow(clippy::needless_range_loop)] - for i in 0..SECURE_EXTENSION_DEGREE { - self.columns[i].set(index, values[i]); - } - } -} - -pub struct SecureColumnByCoordsIter<'a> { - column: &'a SecureColumnByCoords, - index: usize, -} -impl Iterator for SecureColumnByCoordsIter<'_> { - type Item = SecureField; - - fn next(&mut self) -> Option { - if self.index < self.column.len() { - let value = self.column.at(self.index); - self.index += 1; - Some(value) - } else { - None - } - } -} -impl<'a> IntoIterator for &'a SecureColumnByCoords { - type Item = SecureField; - type IntoIter = SecureColumnByCoordsIter<'a>; - - fn into_iter(self) -> Self::IntoIter { - SecureColumnByCoordsIter { - column: self, - index: 0, - } - } -} -impl FromIterator for SecureColumnByCoords { - fn from_iter>(iter: I) -> Self { - let values = iter.into_iter(); - let (lower_bound, _) = values.size_hint(); - let mut columns = array::from_fn(|_| Vec::with_capacity(lower_bound)); - - for value in values { - let coords = value.to_m31_array(); - zip(&mut columns, coords).for_each(|(col, coord)| col.push(coord)); - } - - SecureColumnByCoords { columns } - } -} -impl From> for Vec { - fn from(column: SecureColumnByCoords) -> Self { - column.into_iter().collect() - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/fri.rs b/Stwo_wrapper/crates/prover/src/core/fri.rs deleted file mode 100644 index 0934c77..0000000 --- a/Stwo_wrapper/crates/prover/src/core/fri.rs +++ /dev/null @@ -1,1424 +0,0 @@ -use std::cmp::Reverse; -use std::collections::BTreeMap; -use std::fmt::Debug; -use std::iter::zip; -use std::ops::RangeInclusive; -use itertools::Itertools; -use num_traits::Zero; -use thiserror::Error; -use tracing::{span, Level}; - -use super::backend::CpuBackend; -use super::channel::{Channel, MerkleChannel}; -use super::fields::m31::BaseField; -use super::fields::qm31::SecureField; -use super::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE}; -use super::fields::FieldOps; -use super::poly::circle::{CircleEvaluation, PolyOps, SecureEvaluation}; -use super::poly::line::{LineEvaluation, LinePoly}; -use super::poly::twiddles::TwiddleTree; -use super::poly::BitReversedOrder; -// TODO(andrew): Create fri/ directory, move queries.rs there and split this file up. -use super::queries::{Queries, SparseSubCircleDomain}; -use crate::core::circle::Coset; -use crate::core::fft::ibutterfly; -use crate::core::fields::FieldExpOps; -use crate::core::poly::line::LineDomain; -use crate::core::utils::bit_reverse_index; -use crate::core::vcs::ops::{MerkleHasher, MerkleOps}; -use crate::core::vcs::prover::{MerkleDecommitment, MerkleProver}; -use crate::core::vcs::verifier::{MerkleVerificationError, MerkleVerifier}; - -/// FRI proof config -// TODO(andrew): Support different step sizes. -#[derive(Debug, Clone, Copy)] -pub struct FriConfig { - pub log_blowup_factor: u32, - pub log_last_layer_degree_bound: u32, - pub n_queries: usize, - // TODO(andrew): fold_steps. -} - -impl FriConfig { - const LOG_MIN_LAST_LAYER_DEGREE_BOUND: u32 = 0; - const LOG_MAX_LAST_LAYER_DEGREE_BOUND: u32 = 10; - const LOG_LAST_LAYER_DEGREE_BOUND_RANGE: RangeInclusive = - Self::LOG_MIN_LAST_LAYER_DEGREE_BOUND..=Self::LOG_MAX_LAST_LAYER_DEGREE_BOUND; - - const LOG_MIN_BLOWUP_FACTOR: u32 = 1; - const LOG_MAX_BLOWUP_FACTOR: u32 = 16; - const LOG_BLOWUP_FACTOR_RANGE: RangeInclusive = - Self::LOG_MIN_BLOWUP_FACTOR..=Self::LOG_MAX_BLOWUP_FACTOR; - - /// Creates a new FRI configuration. - /// - /// # Panics - /// - /// Panics if: - /// * `log_last_layer_degree_bound` is greater than 10. - /// * `log_blowup_factor` is equal to zero or greater than 16. - pub fn new(log_last_layer_degree_bound: u32, log_blowup_factor: u32, n_queries: usize) -> Self { - assert!(Self::LOG_LAST_LAYER_DEGREE_BOUND_RANGE.contains(&log_last_layer_degree_bound)); - assert!(Self::LOG_BLOWUP_FACTOR_RANGE.contains(&log_blowup_factor)); - Self { - log_blowup_factor, - log_last_layer_degree_bound, - n_queries, - } - } - - fn last_layer_domain_size(&self) -> usize { - 1 << (self.log_last_layer_degree_bound + self.log_blowup_factor) - } -} - -pub trait FriOps: FieldOps + PolyOps + Sized + FieldOps { - /// Folds a degree `d` polynomial into a degree `d/2` polynomial. - /// - /// Let `eval` be a polynomial evaluated on a [LineDomain] `E`, `alpha` be a random field - /// element and `pi(x) = 2x^2 - 1` be the circle's x-coordinate doubling map. This function - /// returns `f' = f0 + alpha * f1` evaluated on `pi(E)` such that `2f(x) = f0(pi(x)) + x * - /// f1(pi(x))`. - /// - /// # Panics - /// - /// Panics if there are less than two evaluations. - fn fold_line( - eval: &LineEvaluation, - alpha: SecureField, - twiddles: &TwiddleTree, - ) -> LineEvaluation; - - /// Folds and accumulates a degree `d` circle polynomial into a degree `d/2` univariate - /// polynomial. - /// - /// Let `src` be the evaluation of a circle polynomial `f` on a - /// [`CircleDomain`] `E`. This function computes evaluations of `f' = f0 - /// + alpha * f1` on the x-coordinates of `E` such that `2f(p) = f0(px) + py * f1(px)`. The - /// evaluations of `f'` are accumulated into `dst` by the formula `dst = dst * alpha^2 + - /// f'`. - /// - /// # Panics - /// - /// Panics if `src` is not double the length of `dst`. - /// - /// [`CircleDomain`]: super::poly::circle::CircleDomain - // TODO(andrew): Make folding factor generic. - // TODO(andrew): Fold directly into FRI layer to prevent allocation. - fn fold_circle_into_line( - dst: &mut LineEvaluation, - src: &SecureEvaluation, - alpha: SecureField, - twiddles: &TwiddleTree, - ); - - /// Decomposes a FRI-space polynomial into a polynomial inside the fft-space and the - /// remainder term. - /// FRI-space: polynomials of total degree n/2. - /// Based on lemma #12 from the CircleStark paper: f(P) = g(P)+ lambda * alternating(P), - /// where lambda is the cosset diff of eval, and g is a polynomial in the fft-space. - fn decompose( - eval: &SecureEvaluation, - ) -> (SecureEvaluation, SecureField); -} -/// A FRI prover that applies the FRI protocol to prove a set of polynomials are of low degree. -pub struct FriProver, MC: MerkleChannel> { - config: FriConfig, - inner_layers: Vec>, - last_layer_poly: LinePoly, - /// Unique sizes of committed columns sorted in descending order. - column_log_sizes: Vec, -} - -impl, MC: MerkleChannel> FriProver { - /// Commits to multiple [CircleEvaluation]s. - /// - /// `columns` must be provided in descending order by size. - /// - /// Mixed degree STARKs involve polynomials evaluated on multiple domains of different size. - /// Combining evaluations on different sized domains into an evaluation of a single polynomial - /// on a single domain for the purpose of commitment is inefficient. Instead, commit to multiple - /// polynomials so combining of evaluations can be taken care of efficiently at the appropriate - /// FRI layer. All evaluations must be taken over canonic [`CircleDomain`]s. - /// - /// # Panics - /// - /// Panics if: - /// * `columns` is empty or not sorted in ascending order by domain size. - /// * An evaluation is not from a sufficiently low degree circle polynomial. - /// * An evaluation's domain is smaller than the last layer. - /// * An evaluation's domain is not a canonic circle domain. - /// - /// [`CircleDomain`]: super::poly::circle::CircleDomain - // TODO(andrew): Add docs for all evaluations needing to be from canonic domains. - pub fn commit( - channel: &mut MC::C, - config: FriConfig, - columns: &[SecureEvaluation], - twiddles: &TwiddleTree, - ) -> Self { - let _span = span!(Level::INFO, "FRI commitment").entered(); - assert!(!columns.is_empty(), "no columns"); - assert!(columns.is_sorted_by_key(|e| Reverse(e.len())), "not sorted"); - assert!(columns.iter().all(|e| e.domain.is_canonic()), "not canonic"); - let (inner_layers, last_layer_evaluation) = - Self::commit_inner_layers(channel, config, columns, twiddles); - let last_layer_poly = Self::commit_last_layer(channel, config, last_layer_evaluation); - - let column_log_sizes = columns - .iter() - .map(|e| e.domain.log_size()) - .dedup() - .collect(); - Self { - config, - inner_layers, - last_layer_poly, - column_log_sizes, - } - } - - /// Builds and commits to the inner FRI layers (all layers except the last layer). - /// - /// All `columns` must be provided in descending order by size. - /// - /// Returns all inner layers and the evaluation of the last layer. - fn commit_inner_layers( - channel: &mut MC::C, - config: FriConfig, - columns: &[SecureEvaluation], - twiddles: &TwiddleTree, - ) -> (Vec>, LineEvaluation) { - // Returns the length of the [LineEvaluation] a [CircleEvaluation] gets folded into. - let folded_len = - |e: &SecureEvaluation| e.len() >> CIRCLE_TO_LINE_FOLD_STEP; - - let first_layer_size = folded_len(&columns[0]); - let first_layer_domain = LineDomain::new(Coset::half_odds(first_layer_size.ilog2())); - let mut layer_evaluation = LineEvaluation::new_zero(first_layer_domain); - - let mut columns = columns.iter().peekable(); - - let mut layers = Vec::new(); - - // Circle polynomials can all be folded with the same alpha. - let circle_poly_alpha = channel.draw_felt(); - - while layer_evaluation.len() > config.last_layer_domain_size() { - // Check for any columns (circle poly evaluations) that should be combined. - while let Some(column) = columns.next_if(|c| folded_len(c) == layer_evaluation.len()) { - B::fold_circle_into_line( - &mut layer_evaluation, - column, - circle_poly_alpha, - twiddles, - ); - } - - let layer = FriLayerProver::new(layer_evaluation); - MC::mix_root(channel, layer.merkle_tree.root()); - let folding_alpha = channel.draw_felt(); - let folded_layer_evaluation = B::fold_line(&layer.evaluation, folding_alpha, twiddles); - - layer_evaluation = folded_layer_evaluation; - layers.push(layer); - } - - // Check all columns have been consumed. - assert!(columns.is_empty()); - - (layers, layer_evaluation) - } - - /// Builds and commits to the last layer. - /// - /// The layer is committed to by sending the verifier all the coefficients of the remaining - /// polynomial. - /// - /// # Panics - /// - /// Panics if: - /// * The evaluation domain size exceeds the maximum last layer domain size. - /// * The evaluation is not of sufficiently low degree. - fn commit_last_layer( - channel: &mut MC::C, - config: FriConfig, - evaluation: LineEvaluation, - ) -> LinePoly { - assert_eq!(evaluation.len(), config.last_layer_domain_size()); - - let evaluation = evaluation.to_cpu(); - let mut coeffs = evaluation.interpolate().into_ordered_coefficients(); - - let last_layer_degree_bound = 1 << config.log_last_layer_degree_bound; - let zeros = coeffs.split_off(last_layer_degree_bound); - assert!(zeros.iter().all(SecureField::is_zero), "invalid degree"); - - let last_layer_poly = LinePoly::from_ordered_coefficients(coeffs); - channel.mix_felts(&last_layer_poly); - - last_layer_poly - } - - /// Generates a FRI proof and returns it with the opening positions for the committed columns. - /// - /// Returned column opening positions are mapped by their log size. - pub fn decommit( - self, - channel: &mut MC::C, - ) -> (FriProof, BTreeMap) { - let max_column_log_size = self.column_log_sizes[0]; - let queries = Queries::generate(channel, max_column_log_size, self.config.n_queries); - let positions = get_opening_positions(&queries, &self.column_log_sizes); - let proof = self.decommit_on_queries(&queries); - (proof, positions) - } - - /// # Panics - /// - /// Panics if the queries were sampled on the wrong domain size. - fn decommit_on_queries(self, queries: &Queries) -> FriProof { - let max_column_log_size = self.column_log_sizes[0]; - assert_eq!(queries.log_domain_size, max_column_log_size); - let first_layer_queries = queries.fold(CIRCLE_TO_LINE_FOLD_STEP); - let inner_layers = self - .inner_layers - .into_iter() - .scan(first_layer_queries, |layer_queries, layer| { - let layer_proof = layer.decommit(layer_queries); - *layer_queries = layer_queries.fold(FOLD_STEP); - Some(layer_proof) - }) - .collect(); - - let last_layer_poly = self.last_layer_poly; - - FriProof { - inner_layers, - last_layer_poly, - } - } -} - -pub struct FriVerifier { - config: FriConfig, - /// Alpha used to fold all circle polynomials to univariate polynomials. - circle_poly_alpha: SecureField, - /// Domain size queries should be sampled from. - expected_query_log_domain_size: u32, - /// The list of degree bounds of all committed circle polynomials. - column_bounds: Vec, - inner_layers: Vec>, - last_layer_domain: LineDomain, - last_layer_poly: LinePoly, - /// The queries used for decommitment. Initialized when calling - /// [`FriVerifier::column_opening_positions`]. - queries: Option, -} - -impl FriVerifier { - /// Verifies the commitment stage of FRI. - /// - /// `column_bounds` should be the committed circle polynomial degree bounds in descending order. - /// - /// # Errors - /// - /// An `Err` will be returned if: - /// * The proof contains an invalid number of FRI layers. - /// * The degree of the last layer polynomial is too high. - /// - /// # Panics - /// - /// Panics if: - /// * There are no degree bounds. - /// * The degree bounds are not sorted in descending order. - /// * A degree bound is less than or equal to the last layer's degree bound. - pub fn commit( - channel: &mut MC::C, - config: FriConfig, - proof: FriProof, - column_bounds: Vec, - ) -> Result { - assert!(column_bounds.is_sorted_by_key(|b| Reverse(*b))); - - let max_column_bound = column_bounds[0]; - let expected_query_log_domain_size = - max_column_bound.log_degree_bound + config.log_blowup_factor; - - // Circle polynomials can all be folded with the same alpha. - let circle_poly_alpha = channel.draw_felt(); - - let mut inner_layers = Vec::new(); - let mut layer_bound = max_column_bound.fold_to_line(); - let mut layer_domain = LineDomain::new(Coset::half_odds( - layer_bound.log_degree_bound + config.log_blowup_factor, - )); - - for (layer_index, proof) in proof.inner_layers.into_iter().enumerate() { - MC::mix_root(channel, proof.commitment); - - let folding_alpha = channel.draw_felt(); - - inner_layers.push(FriLayerVerifier { - degree_bound: layer_bound, - domain: layer_domain, - folding_alpha, - layer_index, - proof, - }); - - layer_bound = layer_bound - .fold(FOLD_STEP) - .ok_or(FriVerificationError::InvalidNumFriLayers)?; - layer_domain = layer_domain.double(); - } - - if layer_bound.log_degree_bound != config.log_last_layer_degree_bound { - return Err(FriVerificationError::InvalidNumFriLayers); - } - - let last_layer_domain = layer_domain; - let last_layer_poly = proof.last_layer_poly; - - if last_layer_poly.len() > (1 << config.log_last_layer_degree_bound) { - return Err(FriVerificationError::LastLayerDegreeInvalid); - } - - channel.mix_felts(&last_layer_poly); - - Ok(Self { - config, - circle_poly_alpha, - column_bounds, - expected_query_log_domain_size, - inner_layers, - last_layer_domain, - last_layer_poly, - queries: None, - }) - } - - /// Verifies the decommitment stage of FRI. - /// - /// The decommitment values need to be provided in the same order as their commitment. - /// - /// # Panics - /// - /// Panics if: - /// * The queries were not yet sampled. - /// * The queries were sampled on the wrong domain size. - /// * There aren't the same number of decommitted values as degree bounds. - // TODO(andrew): Finish docs. - pub fn decommit( - mut self, - decommitted_values: Vec, - ) -> Result<(), FriVerificationError> { - let queries = self.queries.take().expect("queries not sampled"); - self.decommit_on_queries(&queries, decommitted_values) - } - - fn decommit_on_queries( - self, - queries: &Queries, - decommitted_values: Vec, - ) -> Result<(), FriVerificationError> { - assert_eq!(queries.log_domain_size, self.expected_query_log_domain_size); - assert_eq!(decommitted_values.len(), self.column_bounds.len()); - - let (last_layer_queries, last_layer_query_evals) = - self.decommit_inner_layers(queries, decommitted_values)?; - - self.decommit_last_layer(last_layer_queries, last_layer_query_evals) - } - - /// Verifies all inner layer decommitments. - /// - /// Returns the queries and query evaluations needed for verifying the last FRI layer. - fn decommit_inner_layers( - &self, - queries: &Queries, - decommitted_values: Vec, - ) -> Result<(Queries, Vec), FriVerificationError> { - let circle_poly_alpha = self.circle_poly_alpha; - let circle_poly_alpha_sq = circle_poly_alpha * circle_poly_alpha; - - let mut decommitted_values = decommitted_values.into_iter(); - let mut column_bounds = self.column_bounds.iter().copied().peekable(); - let mut layer_queries = queries.fold(CIRCLE_TO_LINE_FOLD_STEP); - let mut layer_query_evals = vec![SecureField::zero(); layer_queries.len()]; - - for layer in self.inner_layers.iter() { - // Check for column evals that need to folded into this layer. - while column_bounds - .next_if(|b| b.fold_to_line() == layer.degree_bound) - .is_some() - { - let sparse_evaluation = decommitted_values.next().unwrap(); - let folded_evals = sparse_evaluation.fold(circle_poly_alpha); - assert_eq!(folded_evals.len(), layer_query_evals.len()); - - for (layer_eval, folded_eval) in zip(&mut layer_query_evals, folded_evals) { - *layer_eval = *layer_eval * circle_poly_alpha_sq + folded_eval; - } - } - - (layer_queries, layer_query_evals) = - layer.verify_and_fold(layer_queries, layer_query_evals)?; - } - - // Check all values have been consumed. - assert!(column_bounds.is_empty()); - assert!(decommitted_values.is_empty()); - - Ok((layer_queries, layer_query_evals)) - } - - /// Verifies the last layer. - fn decommit_last_layer( - self, - queries: Queries, - query_evals: Vec, - ) -> Result<(), FriVerificationError> { - let Self { - last_layer_domain: domain, - last_layer_poly, - .. - } = self; - - for (&query, query_eval) in zip(&*queries, query_evals) { - let x = domain.at(bit_reverse_index(query, domain.log_size())); - - if query_eval != last_layer_poly.eval_at_point(x.into()) { - return Err(FriVerificationError::LastLayerEvaluationsInvalid); - } - } - - Ok(()) - } - - /// Samples queries and returns the opening positions for each unique column size. - /// - /// The order of the opening positions corresponds to the order of the column commitment. - pub fn column_query_positions( - &mut self, - channel: &mut MC::C, - ) -> BTreeMap { - let column_log_sizes = self - .column_bounds - .iter() - .dedup() - .map(|b| b.log_degree_bound + self.config.log_blowup_factor) - .collect_vec(); - let queries = Queries::generate(channel, column_log_sizes[0], self.config.n_queries); - let positions = get_opening_positions(&queries, &column_log_sizes); - self.queries = Some(queries); - positions - } -} - -/// Returns the column opening positions needed for verification. -/// -/// The column log sizes must be unique and in descending order. Returned -/// column opening positions are mapped by their log size. -fn get_opening_positions( - queries: &Queries, - column_log_sizes: &[u32], -) -> BTreeMap { - let mut prev_log_size = column_log_sizes[0]; - assert!(prev_log_size == queries.log_domain_size); - let mut prev_queries = queries.clone(); - let mut positions = BTreeMap::new(); - positions.insert(prev_log_size, prev_queries.opening_positions(FOLD_STEP)); - for log_size in column_log_sizes.iter().skip(1) { - let n_folds = prev_log_size - log_size; - let queries = prev_queries.fold(n_folds); - positions.insert(*log_size, queries.opening_positions(FOLD_STEP)); - prev_log_size = *log_size; - prev_queries = queries; - } - positions -} - -#[derive(Clone, Copy, Debug, Error)] -pub enum FriVerificationError { - #[error("proof contains an invalid number of FRI layers")] - InvalidNumFriLayers, - #[error("queries do not resolve to their commitment in layer {layer}")] - InnerLayerCommitmentInvalid { - layer: usize, - error: MerkleVerificationError, - }, - #[error("evaluations are invalid in layer {layer}")] - InnerLayerEvaluationsInvalid { layer: usize }, - #[error("degree of last layer is invalid")] - LastLayerDegreeInvalid, - #[error("evaluations in the last layer are invalid")] - LastLayerEvaluationsInvalid, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub struct CirclePolyDegreeBound { - log_degree_bound: u32, -} - -impl CirclePolyDegreeBound { - pub fn new(log_degree_bound: u32) -> Self { - Self { log_degree_bound } - } - - /// Maps a circle polynomial's degree bound to the degree bound of the univariate (line) - /// polynomial it gets folded into. - fn fold_to_line(&self) -> LinePolyDegreeBound { - LinePolyDegreeBound { - log_degree_bound: self.log_degree_bound - CIRCLE_TO_LINE_FOLD_STEP, - } - } -} - -impl PartialOrd for CirclePolyDegreeBound { - fn partial_cmp(&self, other: &LinePolyDegreeBound) -> Option { - Some(self.log_degree_bound.cmp(&other.log_degree_bound)) - } -} - -impl PartialEq for CirclePolyDegreeBound { - fn eq(&self, other: &LinePolyDegreeBound) -> bool { - self.log_degree_bound == other.log_degree_bound - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -struct LinePolyDegreeBound { - log_degree_bound: u32, -} - -impl LinePolyDegreeBound { - /// Returns [None] if the unfolded degree bound is smaller than the folding factor. - fn fold(self, n_folds: u32) -> Option { - if self.log_degree_bound < n_folds { - return None; - } - - let log_degree_bound = self.log_degree_bound - n_folds; - Some(Self { log_degree_bound }) - } -} - -/// A FRI proof. -#[derive(Debug)] -pub struct FriProof { - pub inner_layers: Vec>, - pub last_layer_poly: LinePoly, -} - -/// Number of folds for univariate polynomials. -// TODO(andrew): Support different step sizes. -pub const FOLD_STEP: u32 = 1; - -/// Number of folds when folding a circle polynomial to univariate polynomial. -pub const CIRCLE_TO_LINE_FOLD_STEP: u32 = 1; - -/// Stores a subset of evaluations in a fri layer with their corresponding merkle decommitments. -/// -/// The subset corresponds to the set of evaluations needed by a FRI verifier. -#[derive(Debug)] -pub struct FriLayerProof { - /// The subset stored corresponds to the set of evaluations the verifier doesn't have but needs - /// to fold and verify the merkle decommitment. - pub evals_subset: Vec, - pub decommitment: MerkleDecommitment, - pub commitment: H::Hash, -} - -struct FriLayerVerifier { - degree_bound: LinePolyDegreeBound, - domain: LineDomain, - folding_alpha: SecureField, - layer_index: usize, - proof: FriLayerProof, -} - -impl FriLayerVerifier { - /// Verifies the layer's merkle decommitment and returns the the folded queries and query evals. - /// - /// # Errors - /// - /// An `Err` will be returned if: - /// * The proof doesn't store enough evaluations. - /// * The merkle decommitment is invalid. - /// - /// # Panics - /// - /// Panics if the number of queries doesn't match the number of evals. - fn verify_and_fold( - &self, - queries: Queries, - evals_at_queries: Vec, - ) -> Result<(Queries, Vec), FriVerificationError> { - let decommitment = self.proof.decommitment.clone(); - let commitment = self.proof.commitment; - - // Extract the evals needed for decommitment and folding. - let sparse_evaluation = self.extract_evaluation(&queries, &evals_at_queries)?; - - // TODO: When leaf values are removed from the decommitment, also remove this block. - let actual_decommitment_evals: SecureColumnByCoords = sparse_evaluation - .subline_evals - .iter() - .flat_map(|e| e.values.into_iter()) - .collect(); - - let folded_queries = queries.fold(FOLD_STEP); - - // Positions of all the decommitment evals. - let decommitment_positions = folded_queries - .iter() - .flat_map(|folded_query| { - let start = folded_query << FOLD_STEP; - let end = start + (1 << FOLD_STEP); - start..end - }) - .collect::>(); - - let merkle_verifier = MerkleVerifier::new( - commitment, - vec![self.domain.log_size(); SECURE_EXTENSION_DEGREE], - ); - // TODO(spapini): Propagate error. - merkle_verifier - .verify( - [(self.domain.log_size(), decommitment_positions)] - .into_iter() - .collect(), - actual_decommitment_evals.columns.to_vec(), - decommitment, - ) - .map_err(|e| FriVerificationError::InnerLayerCommitmentInvalid { - layer: self.layer_index, - error: e, - })?; - - let evals_at_folded_queries = sparse_evaluation.fold(self.folding_alpha); - - Ok((folded_queries, evals_at_folded_queries)) - } - - /// Returns the evaluations needed for decommitment. - /// - /// # Errors - /// - /// Returns an `Err` if the proof doesn't store enough evaluations. - /// - /// # Panics - /// - /// Panics if the number of queries doesn't match the number of evals. - fn extract_evaluation( - &self, - queries: &Queries, - evals_at_queries: &[SecureField], - ) -> Result { - // Evals provided by the verifier. - let mut evals_at_queries = evals_at_queries.iter().copied(); - - // Evals stored in the proof. - let mut proof_evals = self.proof.evals_subset.iter().copied(); - - let mut all_subline_evals = Vec::new(); - - // Group queries by the subline they reside in. - for subline_queries in queries.group_by(|a, b| a >> FOLD_STEP == b >> FOLD_STEP) { - let subline_start = (subline_queries[0] >> FOLD_STEP) << FOLD_STEP; - let subline_end = subline_start + (1 << FOLD_STEP); - - let mut subline_evals = Vec::new(); - let mut subline_queries = subline_queries.iter().peekable(); - - // Insert the evals. - for eval_position in subline_start..subline_end { - let eval = match subline_queries.next_if_eq(&&eval_position) { - Some(_) => evals_at_queries.next().unwrap(), - None => proof_evals.next().ok_or( - FriVerificationError::InnerLayerEvaluationsInvalid { - layer: self.layer_index, - }, - )?, - }; - - subline_evals.push(eval); - } - - // Construct the domain. - // TODO(andrew): Create a constructor for LineDomain. - let subline_initial_index = bit_reverse_index(subline_start, self.domain.log_size()); - let subline_initial = self.domain.coset().index_at(subline_initial_index); - let subline_domain = LineDomain::new(Coset::new(subline_initial, FOLD_STEP)); - - all_subline_evals.push(LineEvaluation::new( - subline_domain, - subline_evals.into_iter().collect(), - )); - } - - // Check all proof evals have been consumed. - if !proof_evals.is_empty() { - return Err(FriVerificationError::InnerLayerEvaluationsInvalid { - layer: self.layer_index, - }); - } - - Ok(SparseLineEvaluation::new(all_subline_evals)) - } -} - -/// A FRI layer comprises of a merkle tree that commits to evaluations of a polynomial. -/// -/// The polynomial evaluations are viewed as evaluation of a polynomial on multiple distinct cosets -/// of size two. Each leaf of the merkle tree commits to a single coset evaluation. -// TODO(andrew): Support different step sizes. -struct FriLayerProver, H: MerkleHasher> { - evaluation: LineEvaluation, - merkle_tree: MerkleProver, -} - -impl, H: MerkleHasher> FriLayerProver { - fn new(evaluation: LineEvaluation) -> Self { - // TODO(spapini): Commit on slice. - // TODO(spapini): Merkle tree in backend. - let merkle_tree = MerkleProver::commit(evaluation.values.columns.iter().collect_vec()); - #[allow(unreachable_code)] - FriLayerProver { - evaluation, - merkle_tree, - } - } - - /// Generates a decommitment of the subline evaluations at the specified positions. - fn decommit(self, queries: &Queries) -> FriLayerProof { - let mut decommit_positions = Vec::new(); - let mut evals_subset = Vec::new(); - - // Group queries by the subline they reside in. - // TODO(andrew): Explain what a "subline" is at the top of the module. - for query_group in queries.group_by(|a, b| a >> FOLD_STEP == b >> FOLD_STEP) { - let subline_start = (query_group[0] >> FOLD_STEP) << FOLD_STEP; - let subline_end = subline_start + (1 << FOLD_STEP); - - let mut subline_queries = query_group.iter().peekable(); - - for eval_position in subline_start..subline_end { - // Add decommitment position. - decommit_positions.push(eval_position); - - // Skip evals the verifier can calculate. - if subline_queries.next_if_eq(&&eval_position).is_some() { - continue; - } - - let eval = self.evaluation.values.at(eval_position); - evals_subset.push(eval); - } - } - - let commitment = self.merkle_tree.root(); - // TODO(spapini): Use _evals. - let (_evals, decommitment) = self.merkle_tree.decommit( - [(self.evaluation.len().ilog2(), decommit_positions)] - .into_iter() - .collect(), - self.evaluation.values.columns.iter().collect_vec(), - ); - - FriLayerProof { - evals_subset, - decommitment, - commitment, - } - } -} - -/// Holds a foldable subset of circle polynomial evaluations. -#[derive(Debug, Clone)] -pub struct SparseCircleEvaluation { - subcircle_evals: Vec>, -} - -impl SparseCircleEvaluation { - /// # Panics - /// - /// Panics if the evaluation domain sizes don't equal the folding factor. - pub fn new( - subcircle_evals: Vec>, - ) -> Self { - let folding_factor = 1 << CIRCLE_TO_LINE_FOLD_STEP; - assert!(subcircle_evals.iter().all(|e| e.len() == folding_factor)); - Self { subcircle_evals } - } - - fn fold(self, alpha: SecureField) -> Vec { - self.subcircle_evals - .into_iter() - .map(|e| { - let buffer_domain = LineDomain::new(e.domain.half_coset); - let mut buffer = LineEvaluation::new_zero(buffer_domain); - fold_circle_into_line( - &mut buffer, - &SecureEvaluation::new(e.domain, e.values.into_iter().collect()), - alpha, - ); - buffer.values.at(0) - }) - .collect() - } -} - -impl<'a> IntoIterator for &'a mut SparseCircleEvaluation { - type Item = &'a mut CircleEvaluation; - type IntoIter = - std::slice::IterMut<'a, CircleEvaluation>; - - fn into_iter(self) -> Self::IntoIter { - self.subcircle_evals.iter_mut() - } -} - -/// Holds a small foldable subset of univariate SecureField polynomial evaluations. -/// Evaluation is held at the CPU backend. -#[derive(Debug, Clone)] -struct SparseLineEvaluation { - subline_evals: Vec>, -} - -impl SparseLineEvaluation { - /// # Panics - /// - /// Panics if the evaluation domain sizes don't equal the folding factor. - fn new(subline_evals: Vec>) -> Self { - let folding_factor = 1 << FOLD_STEP; - assert!(subline_evals.iter().all(|e| e.len() == folding_factor)); - Self { subline_evals } - } - - fn fold(self, alpha: SecureField) -> Vec { - self.subline_evals - .into_iter() - .map(|e| fold_line(&e, alpha).values.at(0)) - .collect() - } -} - -/// Folds a degree `d` polynomial into a degree `d/2` polynomial. -/// See [`FriOps::fold_line`]. -pub fn fold_line( - eval: &LineEvaluation, - alpha: SecureField, -) -> LineEvaluation { - let n = eval.len(); - assert!(n >= 2, "Evaluation too small"); - - let domain = eval.domain(); - - let folded_values = eval - .values - .into_iter() - .array_chunks() - .enumerate() - .map(|(i, [f_x, f_neg_x])| { - // TODO(andrew): Inefficient. Update when domain twiddles get stored in a buffer. - let x = domain.at(bit_reverse_index(i << FOLD_STEP, domain.log_size())); - - let (mut f0, mut f1) = (f_x, f_neg_x); - ibutterfly(&mut f0, &mut f1, x.inverse()); - f0 + alpha * f1 - }) - .collect(); - - LineEvaluation::new(domain.double(), folded_values) -} - -/// Folds and accumulates a degree `d` circle polynomial into a degree `d/2` univariate -/// polynomial. -/// See [`FriOps::fold_circle_into_line`]. -pub fn fold_circle_into_line( - dst: &mut LineEvaluation, - src: &SecureEvaluation, - alpha: SecureField, -) { - assert_eq!(src.len() >> CIRCLE_TO_LINE_FOLD_STEP, dst.len()); - - let domain = src.domain; - let alpha_sq = alpha * alpha; - - src.into_iter() - .array_chunks() - .enumerate() - .for_each(|(i, [f_p, f_neg_p])| { - // TODO(andrew): Inefficient. Update when domain twiddles get stored in a buffer. - let p = domain.at(bit_reverse_index( - i << CIRCLE_TO_LINE_FOLD_STEP, - domain.log_size(), - )); - - // Calculate `f0(px)` and `f1(px)` such that `2f(p) = f0(px) + py * f1(px)`. - let (mut f0_px, mut f1_px) = (f_p, f_neg_p); - ibutterfly(&mut f0_px, &mut f1_px, p.y.inverse()); - let f_prime = alpha * f1_px + f0_px; - - dst.values.set(i, dst.values.at(i) * alpha_sq + f_prime); - }); -} - -#[cfg(test)] -mod tests { - use std::iter::zip; - - use itertools::Itertools; - use num_traits::{One, Zero}; - - use super::{get_opening_positions, FriVerificationError, SparseCircleEvaluation}; - use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly}; - use crate::core::backend::{ColumnOps, CpuBackend}; - use crate::core::circle::{CirclePointIndex, Coset}; - use crate::core::fields::m31::BaseField; - use crate::core::fields::qm31::SecureField; - use crate::core::fields::Field; - use crate::core::fri::{ - fold_circle_into_line, fold_line, CirclePolyDegreeBound, FriConfig, - CIRCLE_TO_LINE_FOLD_STEP, - }; - use crate::core::poly::circle::{CircleDomain, PolyOps, SecureEvaluation}; - use crate::core::poly::line::{LineDomain, LineEvaluation, LinePoly}; - use crate::core::poly::{BitReversedOrder, NaturalOrder}; - use crate::core::queries::{Queries, SparseSubCircleDomain}; - use crate::core::test_utils::test_channel; - use crate::core::utils::bit_reverse_index; - use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; - - /// Default blowup factor used for tests. - const LOG_BLOWUP_FACTOR: u32 = 2; - - type FriProver = super::FriProver; - type FriVerifier = super::FriVerifier; - - #[test] - fn fold_line_works() { - const DEGREE: usize = 8; - // Coefficients are bit-reversed. - let even_coeffs: [SecureField; DEGREE / 2] = [1, 2, 1, 3] - .map(BaseField::from_u32_unchecked) - .map(SecureField::from); - let odd_coeffs: [SecureField; DEGREE / 2] = [3, 5, 4, 1] - .map(BaseField::from_u32_unchecked) - .map(SecureField::from); - let poly = LinePoly::new([even_coeffs, odd_coeffs].concat()); - let even_poly = LinePoly::new(even_coeffs.to_vec()); - let odd_poly = LinePoly::new(odd_coeffs.to_vec()); - let alpha = BaseField::from_u32_unchecked(19283).into(); - let domain = LineDomain::new(Coset::half_odds(DEGREE.ilog2())); - let drp_domain = domain.double(); - let mut values = domain - .iter() - .map(|p| poly.eval_at_point(p.into())) - .collect(); - CpuBackend::bit_reverse_column(&mut values); - let evals = LineEvaluation::new(domain, values.into_iter().collect()); - - let drp_evals = fold_line(&evals, alpha); - let mut drp_evals = drp_evals.values.into_iter().collect_vec(); - CpuBackend::bit_reverse_column(&mut drp_evals); - - assert_eq!(drp_evals.len(), DEGREE / 2); - for (i, (&drp_eval, x)) in zip(&drp_evals, drp_domain).enumerate() { - let f_e: SecureField = even_poly.eval_at_point(x.into()); - let f_o: SecureField = odd_poly.eval_at_point(x.into()); - assert_eq!(drp_eval, (f_e + alpha * f_o).double(), "mismatch at {i}"); - } - } - - #[test] - fn fold_circle_to_line_works() { - const LOG_DEGREE: u32 = 4; - let circle_evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); - let alpha = SecureField::one(); - let folded_domain = LineDomain::new(circle_evaluation.domain.half_coset); - - let mut folded_evaluation = LineEvaluation::new_zero(folded_domain); - fold_circle_into_line(&mut folded_evaluation, &circle_evaluation, alpha); - - assert_eq!( - log_degree_bound(folded_evaluation), - LOG_DEGREE - CIRCLE_TO_LINE_FOLD_STEP - ); - } - - #[test] - #[should_panic = "invalid degree"] - fn committing_high_degree_polynomial_fails() { - const LOG_EXPECTED_BLOWUP_FACTOR: u32 = LOG_BLOWUP_FACTOR; - const LOG_INVALID_BLOWUP_FACTOR: u32 = LOG_BLOWUP_FACTOR - 1; - let config = FriConfig::new(2, LOG_EXPECTED_BLOWUP_FACTOR, 3); - let evaluation = polynomial_evaluation(6, LOG_INVALID_BLOWUP_FACTOR); - - FriProver::commit( - &mut test_channel(), - config, - &[evaluation.clone()], - &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), - ); - } - - #[test] - #[should_panic = "not canonic"] - fn committing_evaluation_from_invalid_domain_fails() { - let invalid_domain = CircleDomain::new(Coset::new(CirclePointIndex::generator(), 3)); - assert!(!invalid_domain.is_canonic(), "must be an invalid domain"); - let evaluation = SecureEvaluation::new( - invalid_domain, - vec![SecureField::one(); 1 << 4].into_iter().collect(), - ); - - FriProver::commit( - &mut test_channel(), - FriConfig::new(2, 2, 3), - &[evaluation.clone()], - &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), - ); - } - - #[test] - fn valid_proof_passes_verification() -> Result<(), FriVerificationError> { - const LOG_DEGREE: u32 = 3; - let evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); - let log_domain_size = evaluation.domain.log_size(); - let queries = Queries::from_positions(vec![5], log_domain_size); - let config = FriConfig::new(1, LOG_BLOWUP_FACTOR, queries.len()); - let decommitment_value = query_polynomial(&evaluation, &queries); - let prover = FriProver::commit( - &mut test_channel(), - config, - &[evaluation.clone()], - &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), - ); - let proof = prover.decommit_on_queries(&queries); - let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; - let verifier = FriVerifier::commit(&mut test_channel(), config, proof, bound).unwrap(); - - verifier.decommit_on_queries(&queries, vec![decommitment_value]) - } - - #[test] - fn valid_proof_with_constant_last_layer_passes_verification() -> Result<(), FriVerificationError> - { - const LOG_DEGREE: u32 = 3; - const LAST_LAYER_LOG_BOUND: u32 = 0; - let evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); - let log_domain_size = evaluation.domain.log_size(); - let queries = Queries::from_positions(vec![5], log_domain_size); - let config = FriConfig::new(LAST_LAYER_LOG_BOUND, LOG_BLOWUP_FACTOR, queries.len()); - let decommitment_value = query_polynomial(&evaluation, &queries); - let prover = FriProver::commit( - &mut test_channel(), - config, - &[evaluation.clone()], - &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), - ); - let proof = prover.decommit_on_queries(&queries); - let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; - let verifier = FriVerifier::commit(&mut test_channel(), config, proof, bound).unwrap(); - - verifier.decommit_on_queries(&queries, vec![decommitment_value]) - } - - #[test] - fn valid_mixed_degree_proof_passes_verification() -> Result<(), FriVerificationError> { - const LOG_DEGREES: [u32; 3] = [6, 5, 4]; - let evaluations = LOG_DEGREES.map(|log_d| polynomial_evaluation(log_d, LOG_BLOWUP_FACTOR)); - let log_domain_size = evaluations[0].domain.log_size(); - let queries = Queries::from_positions(vec![7, 70], log_domain_size); - let config = FriConfig::new(2, LOG_BLOWUP_FACTOR, queries.len()); - let prover = FriProver::commit( - &mut test_channel(), - config, - &evaluations, - &CpuBackend::precompute_twiddles(evaluations[0].domain.half_coset), - ); - let decommitment_values = evaluations.map(|p| query_polynomial(&p, &queries)).to_vec(); - let proof = prover.decommit_on_queries(&queries); - let bounds = LOG_DEGREES.map(CirclePolyDegreeBound::new).to_vec(); - let verifier = FriVerifier::commit(&mut test_channel(), config, proof, bounds).unwrap(); - - verifier.decommit_on_queries(&queries, decommitment_values) - } - - #[test] - fn valid_mixed_degree_end_to_end_proof_passes_verification() -> Result<(), FriVerificationError> - { - const LOG_DEGREES: [u32; 3] = [6, 5, 4]; - let evaluations = LOG_DEGREES.map(|log_d| polynomial_evaluation(log_d, LOG_BLOWUP_FACTOR)); - let config = FriConfig::new(2, LOG_BLOWUP_FACTOR, 3); - let prover = FriProver::commit( - &mut test_channel(), - config, - &evaluations, - &CpuBackend::precompute_twiddles(evaluations[0].domain.half_coset), - ); - let (proof, prover_opening_positions) = prover.decommit(&mut test_channel()); - let decommitment_values = zip(&evaluations, prover_opening_positions.values().rev()) - .map(|(poly, positions)| open_polynomial(poly, positions)) - .collect(); - let bounds = LOG_DEGREES.map(CirclePolyDegreeBound::new).to_vec(); - - let mut verifier = FriVerifier::commit(&mut test_channel(), config, proof, bounds).unwrap(); - let verifier_opening_positions = verifier.column_query_positions(&mut test_channel()); - - assert_eq!(prover_opening_positions, verifier_opening_positions); - verifier.decommit(decommitment_values) - } - - #[test] - fn proof_with_removed_layer_fails_verification() { - const LOG_DEGREE: u32 = 6; - let evaluation = polynomial_evaluation(6, LOG_BLOWUP_FACTOR); - let log_domain_size = evaluation.domain.log_size(); - let queries = Queries::from_positions(vec![1], log_domain_size); - let config = FriConfig::new(2, LOG_BLOWUP_FACTOR, queries.len()); - let prover = FriProver::commit( - &mut test_channel(), - config, - &[evaluation.clone()], - &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), - ); - let proof = prover.decommit_on_queries(&queries); - let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; - // Set verifier's config to expect one extra layer than prover config. - let mut invalid_config = config; - invalid_config.log_last_layer_degree_bound -= 1; - - let verifier = FriVerifier::commit(&mut test_channel(), invalid_config, proof, bound); - - assert!(matches!( - verifier, - Err(FriVerificationError::InvalidNumFriLayers) - )); - } - - #[test] - fn proof_with_added_layer_fails_verification() { - const LOG_DEGREE: u32 = 6; - let evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); - let log_domain_size = evaluation.domain.log_size(); - let queries = Queries::from_positions(vec![1], log_domain_size); - let config = FriConfig::new(2, LOG_BLOWUP_FACTOR, queries.len()); - let prover = FriProver::commit( - &mut test_channel(), - config, - &[evaluation.clone()], - &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), - ); - let proof = prover.decommit_on_queries(&queries); - let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; - // Set verifier's config to expect one less layer than prover config. - let mut invalid_config = config; - invalid_config.log_last_layer_degree_bound += 1; - - let verifier = FriVerifier::commit(&mut test_channel(), invalid_config, proof, bound); - - assert!(matches!( - verifier, - Err(FriVerificationError::InvalidNumFriLayers) - )); - } - - #[test] - fn proof_with_invalid_inner_layer_evaluation_fails_verification() { - const LOG_DEGREE: u32 = 6; - let evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); - let log_domain_size = evaluation.domain.log_size(); - let queries = Queries::from_positions(vec![5], log_domain_size); - let config = FriConfig::new(2, LOG_BLOWUP_FACTOR, queries.len()); - let decommitment_value = query_polynomial(&evaluation, &queries); - let prover = FriProver::commit( - &mut test_channel(), - config, - &[evaluation.clone()], - &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), - ); - let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; - let mut proof = prover.decommit_on_queries(&queries); - // Remove an evaluation from the second layer's proof. - proof.inner_layers[1].evals_subset.pop(); - let verifier = FriVerifier::commit(&mut test_channel(), config, proof, bound).unwrap(); - - let verification_result = verifier.decommit_on_queries(&queries, vec![decommitment_value]); - - assert!(matches!( - verification_result, - Err(FriVerificationError::InnerLayerEvaluationsInvalid { layer: 1 }) - )); - } - - #[test] - fn proof_with_invalid_inner_layer_decommitment_fails_verification() { - const LOG_DEGREE: u32 = 6; - let evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); - let log_domain_size = evaluation.domain.log_size(); - let queries = Queries::from_positions(vec![5], log_domain_size); - let config = FriConfig::new(2, LOG_BLOWUP_FACTOR, queries.len()); - let decommitment_value = query_polynomial(&evaluation, &queries); - let prover = FriProver::commit( - &mut test_channel(), - config, - &[evaluation.clone()], - &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), - ); - let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; - let mut proof = prover.decommit_on_queries(&queries); - // Modify the committed values in the second layer. - proof.inner_layers[1].evals_subset[0] += BaseField::one(); - let verifier = FriVerifier::commit(&mut test_channel(), config, proof, bound).unwrap(); - - let verification_result = verifier.decommit_on_queries(&queries, vec![decommitment_value]); - - assert!(matches!( - verification_result, - Err(FriVerificationError::InnerLayerCommitmentInvalid { layer: 1, .. }) - )); - } - - #[test] - fn proof_with_invalid_last_layer_degree_fails_verification() { - const LOG_DEGREE: u32 = 6; - const LOG_MAX_LAST_LAYER_DEGREE: u32 = 2; - let evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); - let log_domain_size = evaluation.domain.log_size(); - let queries = Queries::from_positions(vec![1, 7, 8], log_domain_size); - let config = FriConfig::new(LOG_MAX_LAST_LAYER_DEGREE, LOG_BLOWUP_FACTOR, queries.len()); - let prover = FriProver::commit( - &mut test_channel(), - config, - &[evaluation.clone()], - &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), - ); - let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; - let mut proof = prover.decommit_on_queries(&queries); - let bad_last_layer_coeffs = vec![One::one(); 1 << (LOG_MAX_LAST_LAYER_DEGREE + 1)]; - proof.last_layer_poly = LinePoly::new(bad_last_layer_coeffs); - - let verifier = FriVerifier::commit(&mut test_channel(), config, proof, bound); - - assert!(matches!( - verifier, - Err(FriVerificationError::LastLayerDegreeInvalid) - )); - } - - #[test] - fn proof_with_invalid_last_layer_fails_verification() { - const LOG_DEGREE: u32 = 6; - let evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); - let log_domain_size = evaluation.domain.log_size(); - let queries = Queries::from_positions(vec![1, 7, 8], log_domain_size); - let config = FriConfig::new(2, LOG_BLOWUP_FACTOR, queries.len()); - let decommitment_value = query_polynomial(&evaluation, &queries); - let prover = FriProver::commit( - &mut test_channel(), - config, - &[evaluation.clone()], - &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), - ); - let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; - let mut proof = prover.decommit_on_queries(&queries); - // Compromise the last layer polynomial's first coefficient. - proof.last_layer_poly[0] += BaseField::one(); - let verifier = FriVerifier::commit(&mut test_channel(), config, proof, bound).unwrap(); - - let verification_result = verifier.decommit_on_queries(&queries, vec![decommitment_value]); - - assert!(matches!( - verification_result, - Err(FriVerificationError::LastLayerEvaluationsInvalid) - )); - } - - #[test] - #[should_panic] - fn decommit_queries_on_invalid_domain_fails_verification() { - const LOG_DEGREE: u32 = 3; - let evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); - let log_domain_size = evaluation.domain.log_size(); - let queries = Queries::from_positions(vec![5], log_domain_size); - let config = FriConfig::new(1, LOG_BLOWUP_FACTOR, queries.len()); - let decommitment_value = query_polynomial(&evaluation, &queries); - let prover = FriProver::commit( - &mut test_channel(), - config, - &[evaluation.clone()], - &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), - ); - let proof = prover.decommit_on_queries(&queries); - let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; - let verifier = FriVerifier::commit(&mut test_channel(), config, proof, bound).unwrap(); - // Simulate the verifier sampling queries on a smaller domain. - let mut invalid_queries = queries.clone(); - invalid_queries.log_domain_size -= 1; - - let _ = verifier.decommit_on_queries(&invalid_queries, vec![decommitment_value]); - } - - /// Returns an evaluation of a random polynomial with degree `2^log_degree`. - /// - /// The evaluation domain size is `2^(log_degree + log_blowup_factor)`. - fn polynomial_evaluation( - log_degree: u32, - log_blowup_factor: u32, - ) -> SecureEvaluation { - let poly = CpuCirclePoly::new(vec![BaseField::one(); 1 << log_degree]); - let coset = Coset::half_odds(log_degree + log_blowup_factor - 1); - let domain = CircleDomain::new(coset); - let values = poly.evaluate(domain); - SecureEvaluation::new(domain, values.into_iter().map(SecureField::from).collect()) - } - - /// Returns the log degree bound of a polynomial. - fn log_degree_bound(polynomial: LineEvaluation) -> u32 { - let coeffs = polynomial.interpolate().into_ordered_coefficients(); - let degree = coeffs.into_iter().rposition(|c| !c.is_zero()).unwrap_or(0); - (degree + 1).ilog2() - } - - // TODO: Remove after SubcircleDomain integration. - fn query_polynomial( - polynomial: &SecureEvaluation, - queries: &Queries, - ) -> SparseCircleEvaluation { - let polynomial_log_size = polynomial.domain.log_size(); - let positions = - get_opening_positions(queries, &[queries.log_domain_size, polynomial_log_size]); - open_polynomial(polynomial, &positions[&polynomial_log_size]) - } - - fn open_polynomial( - polynomial: &SecureEvaluation, - positions: &SparseSubCircleDomain, - ) -> SparseCircleEvaluation { - let coset_evals = positions - .iter() - .map(|position| { - let coset_domain = position.to_circle_domain(&polynomial.domain); - let evals = coset_domain - .iter_indices() - .map(|p| { - polynomial.at(bit_reverse_index( - polynomial.domain.find(p).unwrap(), - polynomial.domain.log_size(), - )) - }) - .collect(); - let coset_eval = - CpuCircleEvaluation::::new(coset_domain, evals); - coset_eval.bit_reverse() - }) - .collect(); - - SparseCircleEvaluation::new(coset_evals) - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/lookups/gkr_prover.rs b/Stwo_wrapper/crates/prover/src/core/lookups/gkr_prover.rs deleted file mode 100644 index 6e6ed25..0000000 --- a/Stwo_wrapper/crates/prover/src/core/lookups/gkr_prover.rs +++ /dev/null @@ -1,566 +0,0 @@ -//! GKR batch prover for Grand Product and LogUp lookup arguments. -use std::borrow::Cow; -use std::iter::{successors, zip}; -use std::ops::Deref; - -use educe::Educe; -use itertools::Itertools; -use num_traits::{One, Zero}; -use thiserror::Error; - -use super::gkr_verifier::{GkrArtifact, GkrBatchProof, GkrMask}; -use super::mle::{Mle, MleOps}; -use super::sumcheck::MultivariatePolyOracle; -use super::utils::{eq, random_linear_combination, UnivariatePoly}; -use crate::core::backend::{Col, Column, ColumnOps, CpuBackend}; -use crate::core::channel::Channel; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::{Field, FieldExpOps}; -use crate::core::lookups::sumcheck; - -pub trait GkrOps: MleOps + MleOps { - /// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`. - /// - /// Note [`Mle`] stores values in bit-reversed order. - /// - /// [`eq(x, y)`]: crate::core::lookups::utils::eq - fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle; - - /// Generates the next GKR layer from the current one. - fn next_layer(layer: &Layer) -> Layer; - - /// Returns univariate polynomial `f(t) = sum_x h(t, x)` for all `x` in the boolean hypercube. - /// - /// `claim` equals `f(0) + f(1)`. - /// - /// For more context see docs of [`MultivariatePolyOracle::sum_as_poly_in_first_variable()`]. - fn sum_as_poly_in_first_variable( - h: &GkrMultivariatePolyOracle<'_, Self>, - claim: SecureField, - ) -> UnivariatePoly; -} - -/// Stores evaluations of [`eq(x, y)`] on all boolean hypercube points of the form -/// `x = (0, x_1, ..., x_{n-1})`. -/// -/// Evaluations are stored in bit-reversed order i.e. `evals[0] = eq((0, ..., 0, 0), y)`, -/// `evals[1] = eq((0, ..., 0, 1), y)`, etc. -/// -/// [`eq(x, y)`]: crate::core::lookups::utils::eq -#[derive(Educe)] -#[educe(Debug, Clone)] -pub struct EqEvals> { - y: Vec, - evals: Mle, -} - -impl EqEvals { - pub fn generate(y: &[SecureField]) -> Self { - let y = y.to_vec(); - - if y.is_empty() { - let evals = Mle::new([SecureField::one()].into_iter().collect()); - return Self { evals, y }; - } - - let evals = B::gen_eq_evals(&y[1..], eq(&[SecureField::zero()], &[y[0]])); - assert_eq!(evals.len(), 1 << (y.len() - 1)); - Self { evals, y } - } - - /// Returns the fixed vector `y` used to generate the evaluations. - pub fn y(&self) -> &[SecureField] { - &self.y - } -} - -impl> Deref for EqEvals { - type Target = Col; - - fn deref(&self) -> &Col { - &self.evals - } -} - -/// Represents a layer in a binary tree structured GKR circuit. -/// -/// Layers can contain multiple columns, for example [LogUp] which has separate columns for -/// numerators and denominators. -/// -/// [LogUp]: https://eprint.iacr.org/2023/1284.pdf -#[derive(Educe)] -#[educe(Debug, Clone)] -pub enum Layer { - GrandProduct(Mle), - LogUpGeneric { - numerators: Mle, - denominators: Mle, - }, - LogUpMultiplicities { - numerators: Mle, - denominators: Mle, - }, - /// All numerators implicitly equal "1". - LogUpSingles { - denominators: Mle, - }, -} - -impl Layer { - /// Returns the number of variables used to interpolate the layer's gate values. - pub fn n_variables(&self) -> usize { - match self { - Self::GrandProduct(mle) - | Self::LogUpSingles { denominators: mle } - | Self::LogUpMultiplicities { - denominators: mle, .. - } - | Self::LogUpGeneric { - denominators: mle, .. - } => mle.n_variables(), - } - } - - fn is_output_layer(&self) -> bool { - self.n_variables() == 0 - } - - /// Produces the next layer from the current layer. - /// - /// The next layer is strictly half the size of the current layer. - /// Returns [`None`] if called on an output layer. - pub fn next_layer(&self) -> Option { - if self.is_output_layer() { - return None; - } - - Some(B::next_layer(self)) - } - - /// Returns each column output if the layer is an output layer, otherwise returns an `Err`. - fn try_into_output_layer_values(self) -> Result, NotOutputLayerError> { - if !self.is_output_layer() { - return Err(NotOutputLayerError); - } - - Ok(match self { - Layer::LogUpSingles { denominators } => { - let numerator = SecureField::one(); - let denominator = denominators.at(0); - vec![numerator, denominator] - } - Layer::LogUpMultiplicities { - numerators, - denominators, - } => { - let numerator = numerators.at(0).into(); - let denominator = denominators.at(0); - vec![numerator, denominator] - } - Layer::LogUpGeneric { - numerators, - denominators, - } => { - let numerator = numerators.at(0); - let denominator = denominators.at(0); - vec![numerator, denominator] - } - Layer::GrandProduct(col) => { - vec![col.at(0)] - } - }) - } - - /// Returns a transformed layer with the first variable of each column fixed to `assignment`. - fn fix_first_variable(self, x0: SecureField) -> Self { - if self.n_variables() == 0 { - return self; - } - - match self { - Self::GrandProduct(mle) => Self::GrandProduct(mle.fix_first_variable(x0)), - Self::LogUpGeneric { - numerators, - denominators, - } => Self::LogUpGeneric { - numerators: numerators.fix_first_variable(x0), - denominators: denominators.fix_first_variable(x0), - }, - Self::LogUpMultiplicities { - numerators, - denominators, - } => Self::LogUpGeneric { - numerators: numerators.fix_first_variable(x0), - denominators: denominators.fix_first_variable(x0), - }, - Self::LogUpSingles { denominators } => Self::LogUpSingles { - denominators: denominators.fix_first_variable(x0), - }, - } - } - - /// Represents the next GKR layer evaluation as a multivariate polynomial which uses this GKR - /// layer as input. - /// - /// Layers can contain multiple columns `c_0, ..., c_{n-1}` with multivariate polynomial `g_i` - /// representing[^note] `c_i` in the next layer. These polynomials must be combined with - /// `lambda` into a single polynomial `h = g_0 + lambda * g_1 + ... + lambda^(n-1) * - /// g_{n-1}`. The oracle for `h` should be returned. - /// - /// # Optimization: precomputed [`eq(x, y)`] evals - /// - /// Let `y` be a fixed vector of length `m` and let `z` be a subvector comprising of the - /// last `k` elements of `y`. `h(x)` **must** equal some multivariate polynomial of the form - /// `eq(x, z) * p(x)`. A common operation will be computing the univariate polynomial `f(t) = - /// sum_x h(t, x)` for `x` in the boolean hypercube `{0, 1}^(k-1)`. - /// - /// `eq_evals` stores evaluations of `eq((0, x), y)` for `x` in a potentially extended boolean - /// hypercube `{0, 1}^{m-1}`. These evaluations, on the extended hypercube, can be used directly - /// in computing the sums of `h(x)`, however a correction factor must be applied to the final - /// sum which is handled by [`correct_sum_as_poly_in_first_variable()`] in `O(m)`. - /// - /// Being able to compute sums of `h(x)` using `eq_evals` in this way leads to a more efficient - /// implementation because the prover only has to generate `eq_evals` once for an entire batch - /// of multiple GKR layer instances. - /// - /// [`eq(x, y)`]: crate::core::lookups::utils::eq - /// [^note]: By "representing" we mean `g_i` agrees with the next layer's `c_i` on the boolean - /// hypercube that interpolates `c_i`. - fn into_multivariate_poly( - self, - lambda: SecureField, - eq_evals: &EqEvals, - ) -> GkrMultivariatePolyOracle<'_, B> { - GkrMultivariatePolyOracle { - eq_evals: Cow::Borrowed(eq_evals), - input_layer: self, - eq_fixed_var_correction: SecureField::one(), - lambda, - } - } - - /// Returns a copy of this layer with the [`CpuBackend`]. - /// - /// This operation is expensive but can be useful for small traces that are difficult to handle - /// depending on the backend. For example, the SIMD backend offloads to the CPU backend when - /// trace length becomes smaller than the SIMD lane count. - pub fn to_cpu(&self) -> Layer { - match self { - Layer::GrandProduct(mle) => Layer::GrandProduct(Mle::new(mle.to_cpu())), - Layer::LogUpGeneric { - numerators, - denominators, - } => Layer::LogUpGeneric { - numerators: Mle::new(numerators.to_cpu()), - denominators: Mle::new(denominators.to_cpu()), - }, - Layer::LogUpMultiplicities { - numerators, - denominators, - } => Layer::LogUpMultiplicities { - numerators: Mle::new(numerators.to_cpu()), - denominators: Mle::new(denominators.to_cpu()), - }, - Layer::LogUpSingles { denominators } => Layer::LogUpSingles { - denominators: Mle::new(denominators.to_cpu()), - }, - } - } -} - -#[derive(Debug)] -struct NotOutputLayerError; - -/// Multivariate polynomial `P` that expresses the relation between two consecutive GKR layers. -/// -/// When the input layer is [`Layer::GrandProduct`] (represented by multilinear column `inp`) -/// the polynomial represents: -/// -/// ```text -/// P(x) = eq(x, y) * inp(x, 0) * inp(x, 1) -/// ``` -/// -/// When the input layer is LogUp (represented by multilinear columns `inp_numer` and -/// `inp_denom`) the polynomial represents: -/// -/// ```text -/// numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0) -/// denom(x) = inp_denom(x, 0) * inp_denom(x, 1) -/// -/// P(x) = eq(x, y) * (numer(x) + lambda * denom(x)) -/// ``` -pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> { - /// `eq_evals` passed by `Layer::into_multivariate_poly()`. - pub eq_evals: Cow<'a, EqEvals>, - pub input_layer: Layer, - pub eq_fixed_var_correction: SecureField, - /// Used by LogUp to perform a random linear combination of the numerators and denominators. - pub lambda: SecureField, -} - -impl<'a, B: GkrOps> MultivariatePolyOracle for GkrMultivariatePolyOracle<'a, B> { - fn n_variables(&self) -> usize { - self.input_layer.n_variables() - 1 - } - - fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly { - B::sum_as_poly_in_first_variable(self, claim) - } - - fn fix_first_variable(self, challenge: SecureField) -> Self { - if self.is_constant() { - return self; - } - - let z0 = self.eq_evals.y()[self.eq_evals.y().len() - self.n_variables()]; - let eq_fixed_var_correction = self.eq_fixed_var_correction * eq(&[challenge], &[z0]); - - Self { - eq_evals: self.eq_evals, - eq_fixed_var_correction, - input_layer: self.input_layer.fix_first_variable(challenge), - lambda: self.lambda, - } - } -} - -impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> { - fn is_constant(&self) -> bool { - self.n_variables() == 0 - } - - /// Returns all input layer columns restricted to a line. - /// - /// Let `l` be the line satisfying `l(0) = b*` and `l(1) = c*`. Oracles that represent constants - /// are expressed by values `c_i(b*)` and `c_i(c*)` where `c_i` represents the input GKR layer's - /// `i`th column (for binary tree GKR `b* = (r, 0)`, `c* = (r, 1)`). - /// - /// If this oracle represents a constant, then each `c_i` restricted to `l` is returned. - /// Otherwise, an [`Err`] is returned. - /// - /// For more context see page 64. - fn try_into_mask(self) -> Result { - if !self.is_constant() { - return Err(NotConstantPolyError); - } - - let columns = match self.input_layer { - Layer::GrandProduct(mle) => vec![mle.to_cpu().try_into().unwrap()], - Layer::LogUpGeneric { - numerators, - denominators, - } => { - let numerators = numerators.to_cpu().try_into().unwrap(); - let denominators = denominators.to_cpu().try_into().unwrap(); - vec![numerators, denominators] - } - // Should never get called. - Layer::LogUpMultiplicities { .. } => unimplemented!(), - Layer::LogUpSingles { denominators } => { - let numerators = [SecureField::one(); 2]; - let denominators = denominators.to_cpu().try_into().unwrap(); - vec![numerators, denominators] - } - }; - - Ok(GkrMask::new(columns)) - } - - /// Returns a copy of this oracle with the [`CpuBackend`]. - /// - /// This operation is expensive but can be useful for small oracles that are difficult to handle - /// depending on the backend. For example, the SIMD backend offloads to the CPU backend when - /// trace length becomes smaller than the SIMD lane count. - pub fn to_cpu(&self) -> GkrMultivariatePolyOracle<'a, CpuBackend> { - // TODO(andrew): This block is not ideal. - let n_eq_evals = 1 << (self.n_variables() - 1); - let eq_evals = Cow::Owned(EqEvals { - evals: Mle::new((0..n_eq_evals).map(|i| self.eq_evals.at(i)).collect()), - y: self.eq_evals.y.to_vec(), - }); - - GkrMultivariatePolyOracle { - eq_evals, - eq_fixed_var_correction: self.eq_fixed_var_correction, - input_layer: self.input_layer.to_cpu(), - lambda: self.lambda, - } - } -} - -/// Error returned when a polynomial is expected to be constant but it is not. -#[derive(Debug, Error)] -#[error("polynomial is not constant")] -pub struct NotConstantPolyError; - -/// Batch proves lookup circuits with GKR. -/// -/// The input layers should be committed to the channel before calling this function. -// GKR algorithm: (page 64) -pub fn prove_batch( - channel: &mut impl Channel, - input_layer_by_instance: Vec>, -) -> (GkrBatchProof, GkrArtifact) { - let n_instances = input_layer_by_instance.len(); - let n_layers_by_instance = input_layer_by_instance - .iter() - .map(|l| l.n_variables()) - .collect_vec(); - let n_layers = *n_layers_by_instance.iter().max().unwrap(); - - // Evaluate all instance circuits and collect the layer values. - let mut layers_by_instance = input_layer_by_instance - .into_iter() - .map(|input_layer| gen_layers(input_layer).into_iter().rev()) - .collect_vec(); - - let mut output_claims_by_instance = vec![None; n_instances]; - let mut layer_masks_by_instance = (0..n_instances).map(|_| Vec::new()).collect_vec(); - let mut sumcheck_proofs = Vec::new(); - - let mut ood_point = Vec::new(); - let mut claims_to_verify_by_instance = vec![None; n_instances]; - - for layer in 0..n_layers { - let n_remaining_layers = n_layers - layer; - - // Check all the instances for output layers. - for (instance, layers) in layers_by_instance.iter_mut().enumerate() { - if n_layers_by_instance[instance] == n_remaining_layers { - let output_layer = layers.next().unwrap(); - let output_layer_values = output_layer.try_into_output_layer_values().unwrap(); - claims_to_verify_by_instance[instance] = Some(output_layer_values.clone()); - output_claims_by_instance[instance] = Some(output_layer_values); - } - } - - // Seed the channel with layer claims. - for claims_to_verify in claims_to_verify_by_instance.iter().flatten() { - channel.mix_felts(claims_to_verify); - } - - let eq_evals = EqEvals::generate(&ood_point); - let sumcheck_alpha = channel.draw_felt(); - let instance_lambda = channel.draw_felt(); - - let mut sumcheck_oracles = Vec::new(); - let mut sumcheck_claims = Vec::new(); - let mut sumcheck_instances = Vec::new(); - - // Create the multivariate polynomial oracles used with sumcheck. - for (instance, claims_to_verify) in claims_to_verify_by_instance.iter().enumerate() { - if let Some(claims_to_verify) = claims_to_verify { - let layer = layers_by_instance[instance].next().unwrap(); - sumcheck_oracles.push(layer.into_multivariate_poly(instance_lambda, &eq_evals)); - sumcheck_claims.push(random_linear_combination(claims_to_verify, instance_lambda)); - sumcheck_instances.push(instance); - } - } - - let (sumcheck_proof, sumcheck_ood_point, constant_poly_oracles, _) = - sumcheck::prove_batch(sumcheck_claims, sumcheck_oracles, sumcheck_alpha, channel); - - sumcheck_proofs.push(sumcheck_proof); - - let masks = constant_poly_oracles - .into_iter() - .map(|oracle| oracle.try_into_mask().unwrap()) - .collect_vec(); - - // Seed the channel with the layer masks. - for (&instance, mask) in zip(&sumcheck_instances, &masks) { - channel.mix_felts(mask.columns().flatten()); - layer_masks_by_instance[instance].push(mask.clone()); - } - - let challenge = channel.draw_felt(); - ood_point = sumcheck_ood_point; - ood_point.push(challenge); - - // Set the claims to prove in the layer above. - for (instance, mask) in zip(sumcheck_instances, masks) { - claims_to_verify_by_instance[instance] = Some(mask.reduce_at_point(challenge)); - } - } - - let output_claims_by_instance = output_claims_by_instance - .into_iter() - .map(Option::unwrap) - .collect(); - - let claims_to_verify_by_instance = claims_to_verify_by_instance - .into_iter() - .map(Option::unwrap) - .collect(); - - let proof = GkrBatchProof { - sumcheck_proofs, - layer_masks_by_instance, - output_claims_by_instance, - }; - - let artifact = GkrArtifact { - ood_point, - claims_to_verify_by_instance, - n_variables_by_instance: n_layers_by_instance, - }; - - (proof, artifact) -} - -/// Executes the GKR circuit on the input layer and returns all the circuit's layers. -fn gen_layers(input_layer: Layer) -> Vec> { - let n_variables = input_layer.n_variables(); - let layers = successors(Some(input_layer), |layer| layer.next_layer()).collect_vec(); - assert_eq!(layers.len(), n_variables + 1); - layers -} - -/// Computes `r(t) = sum_x eq((t, x), y[-k:]) * p(t, x)` from evaluations of -/// `f(t) = sum_x eq(({0}^(n - k), 0, x), y) * p(t, x)`. -/// -/// Note `claim` must equal `r(0) + r(1)` and `r` must have degree <= 3. -/// -/// For more context see `Layer::into_multivariate_poly()` docs. -/// See also (section 3.2). -pub fn correct_sum_as_poly_in_first_variable( - f_at_0: SecureField, - f_at_2: SecureField, - claim: SecureField, - y: &[SecureField], - k: usize, -) -> UnivariatePoly { - assert_ne!(k, 0); - let n = y.len(); - assert!(k <= n); - - // We evaluated `f(0)` and `f(2)` - the inputs. - // We want to compute `r(t) = f(t) * eq(t, y[n - k]) / eq(0, y[:n - k + 1])`. - let a_const = eq(&vec![SecureField::zero(); n - k + 1], &y[..n - k + 1]).inverse(); - - // Find the additional root of `r(t)`, by finding the root of `eq(t, y[n - k])`: - // 0 = eq(t, y[n - k]) - // = t * y[n - k] + (1 - t)(1 - y[n - k]) - // = 1 - y[n - k] - t(1 - 2 * y[n - k]) - // => t = (1 - y[n - k]) / (1 - 2 * y[n - k]) - // = b - let b_const = (SecureField::one() - y[n - k]) / (SecureField::one() - y[n - k].double()); - - // We get that `r(t) = f(t) * eq(t, y[n - k]) * a`. - let r_at_0 = f_at_0 * eq(&[SecureField::zero()], &[y[n - k]]) * a_const; - let r_at_1 = claim - r_at_0; - let r_at_2 = f_at_2 * eq(&[BaseField::from(2).into()], &[y[n - k]]) * a_const; - let r_at_b = SecureField::zero(); - - // Interpolate. - UnivariatePoly::interpolate_lagrange( - &[ - SecureField::zero(), - SecureField::one(), - SecureField::from(BaseField::from(2)), - b_const, - ], - &[r_at_0, r_at_1, r_at_2, r_at_b], - ) -} diff --git a/Stwo_wrapper/crates/prover/src/core/lookups/gkr_verifier.rs b/Stwo_wrapper/crates/prover/src/core/lookups/gkr_verifier.rs deleted file mode 100644 index b65ceb1..0000000 --- a/Stwo_wrapper/crates/prover/src/core/lookups/gkr_verifier.rs +++ /dev/null @@ -1,357 +0,0 @@ -//! GKR batch verifier for Grand Product and LogUp lookup arguments. -use thiserror::Error; - -use super::sumcheck::{SumcheckError, SumcheckProof}; -use super::utils::{eq, fold_mle_evals, random_linear_combination}; -use crate::core::channel::Channel; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::lookups::sumcheck; -use crate::core::lookups::utils::Fraction; - -/// Partially verifies a batch GKR proof. -/// -/// On successful verification the function returns a [`GkrArtifact`] which stores the out-of-domain -/// point and claimed evaluations in the input layer columns for each instance at the OOD point. -/// These claimed evaluations are not checked in this function - hence partial verification. -pub fn partially_verify_batch( - gate_by_instance: Vec, - proof: &GkrBatchProof, - channel: &mut impl Channel, -) -> Result { - let GkrBatchProof { - sumcheck_proofs, - layer_masks_by_instance, - output_claims_by_instance, - } = proof; - - if layer_masks_by_instance.len() != output_claims_by_instance.len() { - return Err(GkrError::MalformedProof); - } - - let n_instances = layer_masks_by_instance.len(); - let instance_n_layers = |instance: usize| layer_masks_by_instance[instance].len(); - let n_layers = (0..n_instances).map(instance_n_layers).max().unwrap(); - - if n_layers != sumcheck_proofs.len() { - return Err(GkrError::MalformedProof); - } - - if gate_by_instance.len() != n_instances { - return Err(GkrError::NumInstancesMismatch { - given: gate_by_instance.len(), - proof: n_instances, - }); - } - - let mut ood_point = vec![]; - let mut claims_to_verify_by_instance = vec![None; n_instances]; - - for (layer, sumcheck_proof) in sumcheck_proofs.iter().enumerate() { - let n_remaining_layers = n_layers - layer; - - // Check for output layers. - for instance in 0..n_instances { - if instance_n_layers(instance) == n_remaining_layers { - let output_claims = output_claims_by_instance[instance].clone(); - claims_to_verify_by_instance[instance] = Some(output_claims); - } - } - - // Seed the channel with layer claims. - for claims_to_verify in claims_to_verify_by_instance.iter().flatten() { - channel.mix_felts(claims_to_verify); - } - - let sumcheck_alpha = channel.draw_felt(); - let instance_lambda = channel.draw_felt(); - - let mut sumcheck_claims = Vec::new(); - let mut sumcheck_instances = Vec::new(); - - // Prepare the sumcheck claim. - for (instance, claims_to_verify) in claims_to_verify_by_instance.iter().enumerate() { - if let Some(claims_to_verify) = claims_to_verify { - let n_unused_variables = n_layers - instance_n_layers(instance); - let doubling_factor = BaseField::from(1 << n_unused_variables); - let claim = - random_linear_combination(claims_to_verify, instance_lambda) * doubling_factor; - sumcheck_claims.push(claim); - sumcheck_instances.push(instance); - } - } - - let sumcheck_claim = random_linear_combination(&sumcheck_claims, sumcheck_alpha); - let (sumcheck_ood_point, sumcheck_eval) = - sumcheck::partially_verify(sumcheck_claim, sumcheck_proof, channel) - .map_err(|source| GkrError::InvalidSumcheck { layer, source })?; - - let mut layer_evals = Vec::new(); - - // Evaluate the circuit locally at sumcheck OOD point. - for &instance in &sumcheck_instances { - let n_unused = n_layers - instance_n_layers(instance); - let mask = &layer_masks_by_instance[instance][layer - n_unused]; - let gate = &gate_by_instance[instance]; - let gate_output = gate.eval(mask).map_err(|InvalidNumMaskColumnsError| { - let instance_layer = instance_n_layers(layer) - n_remaining_layers; - GkrError::InvalidMask { - instance, - instance_layer, - } - })?; - // TODO: Consider simplifying the code by just using the same eq eval for all instances - // regardless of size. - let eq_eval = eq(&ood_point[n_unused..], &sumcheck_ood_point[n_unused..]); - layer_evals.push(eq_eval * random_linear_combination(&gate_output, instance_lambda)); - } - - let layer_eval = random_linear_combination(&layer_evals, sumcheck_alpha); - - if sumcheck_eval != layer_eval { - return Err(GkrError::CircuitCheckFailure { - claim: sumcheck_eval, - output: layer_eval, - layer, - }); - } - - // Seed the channel with the layer masks. - for &instance in &sumcheck_instances { - let n_unused = n_layers - instance_n_layers(instance); - let mask = &layer_masks_by_instance[instance][layer - n_unused]; - channel.mix_felts(mask.columns().flatten()); - } - - // Set the OOD evaluation point for layer above. - let challenge = channel.draw_felt(); - ood_point = sumcheck_ood_point; - ood_point.push(challenge); - - // Set the claims to verify in the layer above. - for instance in sumcheck_instances { - let n_unused = n_layers - instance_n_layers(instance); - let mask = &layer_masks_by_instance[instance][layer - n_unused]; - claims_to_verify_by_instance[instance] = Some(mask.reduce_at_point(challenge)); - } - } - - let claims_to_verify_by_instance = claims_to_verify_by_instance - .into_iter() - .map(Option::unwrap) - .collect(); - - Ok(GkrArtifact { - ood_point, - claims_to_verify_by_instance, - n_variables_by_instance: (0..n_instances).map(instance_n_layers).collect(), - }) -} - -/// Batch GKR proof. -pub struct GkrBatchProof { - /// Sum-check proof for each layer. - pub sumcheck_proofs: Vec, - /// Mask for each layer for each instance. - pub layer_masks_by_instance: Vec>, - /// Column circuit outputs for each instance. - pub output_claims_by_instance: Vec>, -} - -/// Values of interest obtained from the execution of the GKR protocol. -pub struct GkrArtifact { - /// Out-of-domain (OOD) point for evaluating columns in the input layer. - pub ood_point: Vec, - /// The claimed evaluation at `ood_point` for each column in the input layer of each instance. - pub claims_to_verify_by_instance: Vec>, - /// The number of variables that interpolate the input layer of each instance. - pub n_variables_by_instance: Vec, -} - -/// Defines how a circuit operates locally on two input rows to produce a single output row. -/// This local 2-to-1 constraint is what gives the whole circuit its "binary tree" structure. -/// -/// Binary tree structured circuits have a highly regular wiring pattern that fit the structure of -/// the circuits defined in [Thaler13] which allow for efficient linear time (linear in size of the -/// circuit) GKR prover implementations. -/// -/// [Thaler13]: https://eprint.iacr.org/2013/351.pdf -#[derive(Debug, Clone, Copy)] -pub enum Gate { - LogUp, - GrandProduct, -} - -impl Gate { - /// Returns the output after applying the gate to the mask. - fn eval(&self, mask: &GkrMask) -> Result, InvalidNumMaskColumnsError> { - Ok(match self { - Self::LogUp => { - if mask.columns().len() != 2 { - return Err(InvalidNumMaskColumnsError); - } - - let [numerator_a, numerator_b] = mask.columns()[0]; - let [denominator_a, denominator_b] = mask.columns()[1]; - - let a = Fraction::new(numerator_a, denominator_a); - let b = Fraction::new(numerator_b, denominator_b); - let res = a + b; - - vec![res.numerator, res.denominator] - } - Self::GrandProduct => { - if mask.columns().len() != 1 { - return Err(InvalidNumMaskColumnsError); - } - - let [a, b] = mask.columns()[0]; - vec![a * b] - } - }) - } -} - -/// Mask has an invalid number of columns -#[derive(Debug)] -struct InvalidNumMaskColumnsError; - -/// Stores two evaluations of each column in a GKR layer. -#[derive(Debug, Clone)] -pub struct GkrMask { - columns: Vec<[SecureField; 2]>, -} - -impl GkrMask { - pub fn new(columns: Vec<[SecureField; 2]>) -> Self { - Self { columns } - } - - pub fn to_rows(&self) -> [Vec; 2] { - self.columns.iter().map(|[a, b]| (a, b)).unzip().into() - } - - pub fn columns(&self) -> &[[SecureField; 2]] { - &self.columns - } - - /// Returns all `p_i(x)` where `p_i` interpolates column `i` of the mask on `{0, 1}`. - pub fn reduce_at_point(&self, x: SecureField) -> Vec { - self.columns - .iter() - .map(|&[v0, v1]| fold_mle_evals(x, v0, v1)) - .collect() - } -} - -/// Error encountered during GKR protocol verification. -#[derive(Error, Debug)] -pub enum GkrError { - /// The proof is malformed. - #[error("proof data is invalid")] - MalformedProof, - /// Mask has an invalid number of columns. - #[error("mask in layer {instance_layer} of instance {instance} is invalid")] - InvalidMask { - instance: usize, - /// Layer of the instance (but not necessarily the batch). - instance_layer: LayerIndex, - }, - /// There is a mismatch between the number of instances in the proof and the number of - /// instances passed for verification. - #[error("provided an invalid number of instances (given {given}, proof expects {proof})")] - NumInstancesMismatch { given: usize, proof: usize }, - /// There was an error with one of the sumcheck proofs. - #[error("sum-check invalid in layer {layer}: {source}")] - InvalidSumcheck { - layer: LayerIndex, - source: SumcheckError, - }, - /// The circuit polynomial the verifier evaluated doesn't match claim from sumcheck. - #[error("circuit check failed in layer {layer} (calculated {output}, claim {claim})")] - CircuitCheckFailure { - claim: SecureField, - output: SecureField, - layer: LayerIndex, - }, -} - -/// GKR layer index where 0 corresponds to the output layer. -pub type LayerIndex = usize; - -#[cfg(test)] -mod tests { - use super::{partially_verify_batch, Gate, GkrArtifact, GkrError}; - use crate::core::backend::CpuBackend; - use crate::core::channel::Channel; - use crate::core::fields::qm31::SecureField; - use crate::core::lookups::gkr_prover::{prove_batch, Layer}; - use crate::core::lookups::mle::Mle; - use crate::core::test_utils::test_channel; - - #[test] - fn prove_batch_works() -> Result<(), GkrError> { - const LOG_N: usize = 5; - let mut channel = test_channel(); - let col0 = Mle::::new(channel.draw_felts(1 << LOG_N)); - let col1 = Mle::::new(channel.draw_felts(1 << LOG_N)); - let product0 = col0.iter().product::(); - let product1 = col1.iter().product::(); - let input_layers = vec![ - Layer::GrandProduct(col0.clone()), - Layer::GrandProduct(col1.clone()), - ]; - let (proof, _) = prove_batch(&mut test_channel(), input_layers); - - let GkrArtifact { - ood_point, - claims_to_verify_by_instance, - n_variables_by_instance, - } = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?; - - assert_eq!(n_variables_by_instance, [LOG_N, LOG_N]); - assert_eq!(proof.output_claims_by_instance.len(), 2); - assert_eq!(claims_to_verify_by_instance.len(), 2); - assert_eq!(proof.output_claims_by_instance[0], &[product0]); - assert_eq!(proof.output_claims_by_instance[1], &[product1]); - let claim0 = &claims_to_verify_by_instance[0]; - let claim1 = &claims_to_verify_by_instance[1]; - assert_eq!(claim0, &[col0.eval_at_point(&ood_point)]); - assert_eq!(claim1, &[col1.eval_at_point(&ood_point)]); - Ok(()) - } - - #[test] - fn prove_batch_with_different_sizes_works() -> Result<(), GkrError> { - const LOG_N0: usize = 5; - const LOG_N1: usize = 7; - let mut channel = test_channel(); - let col0 = Mle::::new(channel.draw_felts(1 << LOG_N0)); - let col1 = Mle::::new(channel.draw_felts(1 << LOG_N1)); - let product0 = col0.iter().product::(); - let product1 = col1.iter().product::(); - let input_layers = vec![ - Layer::GrandProduct(col0.clone()), - Layer::GrandProduct(col1.clone()), - ]; - let (proof, _) = prove_batch(&mut test_channel(), input_layers); - - let GkrArtifact { - ood_point, - claims_to_verify_by_instance, - n_variables_by_instance, - } = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?; - - assert_eq!(n_variables_by_instance, [LOG_N0, LOG_N1]); - assert_eq!(proof.output_claims_by_instance.len(), 2); - assert_eq!(claims_to_verify_by_instance.len(), 2); - assert_eq!(proof.output_claims_by_instance[0], &[product0]); - assert_eq!(proof.output_claims_by_instance[1], &[product1]); - let claim0 = &claims_to_verify_by_instance[0]; - let claim1 = &claims_to_verify_by_instance[1]; - let n_vars = ood_point.len(); - assert_eq!(claim0, &[col0.eval_at_point(&ood_point[n_vars - LOG_N0..])]); - assert_eq!(claim1, &[col1.eval_at_point(&ood_point[n_vars - LOG_N1..])]); - Ok(()) - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/lookups/mle.rs b/Stwo_wrapper/crates/prover/src/core/lookups/mle.rs deleted file mode 100644 index 7449f40..0000000 --- a/Stwo_wrapper/crates/prover/src/core/lookups/mle.rs +++ /dev/null @@ -1,106 +0,0 @@ -use std::ops::{Deref, DerefMut}; - -use educe::Educe; - -use crate::core::backend::{Col, Column, ColumnOps}; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::Field; - -pub trait MleOps: ColumnOps + Sized { - /// Returns a transformed [`Mle`] where the first variable is fixed to `assignment`. - fn fix_first_variable(mle: Mle, assignment: SecureField) -> Mle - where - Self: MleOps; -} - -/// Multilinear Extension stored as evaluations of a multilinear polynomial over the boolean -/// hypercube in bit-reversed order. -#[derive(Educe)] -#[educe(Debug, Clone)] -pub struct Mle, F: Field> { - evals: Col, -} - -impl, F: Field> Mle { - /// Creates a [`Mle`] from evaluations of a multilinear polynomial on the boolean hypercube. - /// - /// # Panics - /// - /// Panics if the number of evaluations is not a power of two. - pub fn new(evals: Col) -> Self { - assert!(evals.len().is_power_of_two()); - Self { evals } - } - - pub fn into_evals(self) -> Col { - self.evals - } - - /// Returns a transformed polynomial where the first variable is fixed to `assignment`. - pub fn fix_first_variable(self, assignment: SecureField) -> Mle - where - B: MleOps, - { - B::fix_first_variable(self, assignment) - } - - /// Returns the number of variables in the polynomial. - pub fn n_variables(&self) -> usize { - self.evals.len().ilog2() as usize - } -} - -impl, F: Field> Deref for Mle { - type Target = Col; - - fn deref(&self) -> &Col { - &self.evals - } -} - -impl, F: Field> DerefMut for Mle { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.evals - } -} - -#[cfg(test)] -mod test { - use super::{Mle, MleOps}; - use crate::core::backend::Column; - use crate::core::fields::qm31::SecureField; - use crate::core::fields::{ExtensionOf, Field}; - - impl Mle - where - F: Field, - SecureField: ExtensionOf, - B: MleOps, - { - /// Evaluates the multilinear polynomial at `point`. - pub(crate) fn eval_at_point(&self, point: &[SecureField]) -> SecureField { - pub fn eval(mle_evals: &[SecureField], p: &[SecureField]) -> SecureField { - match p { - [] => mle_evals[0], - &[p_i, ref p @ ..] => { - let (lhs, rhs) = mle_evals.split_at(mle_evals.len() / 2); - let lhs_eval = eval(lhs, p); - let rhs_eval = eval(rhs, p); - // Equivalent to `eq(0, p_i) * lhs_eval + eq(1, p_i) * rhs_eval`. - p_i * (rhs_eval - lhs_eval) + lhs_eval - } - } - } - - let mle_evals = self - .clone() - .into_evals() - .to_cpu() - .into_iter() - .map(|v| v.into()) - .collect::>(); - - eval(&mle_evals, point) - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/lookups/mod.rs b/Stwo_wrapper/crates/prover/src/core/lookups/mod.rs deleted file mode 100644 index 8f7351a..0000000 --- a/Stwo_wrapper/crates/prover/src/core/lookups/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub mod gkr_prover; -pub mod gkr_verifier; -pub mod mle; -pub mod sumcheck; -pub mod utils; diff --git a/Stwo_wrapper/crates/prover/src/core/lookups/sumcheck.rs b/Stwo_wrapper/crates/prover/src/core/lookups/sumcheck.rs deleted file mode 100644 index 1df2451..0000000 --- a/Stwo_wrapper/crates/prover/src/core/lookups/sumcheck.rs +++ /dev/null @@ -1,292 +0,0 @@ -//! Sum-check protocol that proves and verifies claims about `sum_x g(x)` for all x in `{0, 1}^n`. -//! -//! [`MultivariatePolyOracle`] provides methods for evaluating sums and making transformations on -//! `g` in the context of the protocol. It is intended to be used in conjunction with -//! [`prove_batch()`] to generate proofs. - -use std::iter::zip; - -use itertools::Itertools; -use num_traits::{One, Zero}; -use thiserror::Error; - -use super::utils::UnivariatePoly; -use crate::core::channel::Channel; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; - -/// Something that can be seen as a multivariate polynomial `g(x_0, ..., x_{n-1})`. -pub trait MultivariatePolyOracle: Sized { - /// Returns the number of variables in `g`. - fn n_variables(&self) -> usize; - - /// Computes the sum of `g(x_0, x_1, ..., x_{n-1})` over all `(x_1, ..., x_{n-1})` in - /// `{0, 1}^(n-1)`, effectively reducing the sum over `g` to a univariate polynomial in `x_0`. - /// - /// `claim` equals the claimed sum of `g(x_0, x_2, ..., x_{n-1})` over all `(x_0, ..., x_{n-1})` - /// in `{0, 1}^n`. Knowing the claim can help optimize the implementation: Let `f` denote the - /// univariate polynomial we want to return. Note that `claim = f(0) + f(1)` so knowing `claim` - /// and either `f(0)` or `f(1)` allows determining the other. - fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly; - - /// Returns a transformed oracle where the first variable of `g` is fixed to `challenge`. - /// - /// The returned oracle represents the multivariate polynomial `g'`, defined as - /// `g'(x_1, ..., x_{n-1}) = g(challenge, x_1, ..., x_{n-1})`. - fn fix_first_variable(self, challenge: SecureField) -> Self; -} - -/// Performs sum-check on a random linear combinations of multiple multivariate polynomials. -/// -/// Let the multivariate polynomials be `g_0, ..., g_{n-1}`. A single sum-check is performed on -/// multivariate polynomial `h = g_0 + lambda * g_1 + ... + lambda^(n-1) * g_{n-1}`. The `g_i`s do -/// not need to have the same number of variables. `g_i`s with less variables are folded in the -/// latest possible round of the protocol. For instance with `g_0(x, y, z)` and `g_1(x, y)` -/// sum-check is performed on `h(x, y, z) = g_0(x, y, z) + lambda * g_1(y, z)`. Claim `c_i` should -/// equal the claimed sum of `g_i(x_0, ..., x_{j-1})` over all `(x_0, ..., x_{j-1})` in `{0, 1}^j`. -/// -/// The degree of each `g_i` should not exceed [`MAX_DEGREE`] in any variable. The sum-check proof -/// of `h`, list of challenges (variable assignment) and the constant oracles (i.e. the `g_i` with -/// all variables fixed to the their corresponding challenges) are returned. -/// -/// Output is of the form: `(proof, variable_assignment, constant_poly_oracles, claimed_evals)` -/// -/// # Panics -/// -/// Panics if: -/// - No multivariate polynomials are provided. -/// - There aren't the same number of multivariate polynomials and claims. -/// - The degree of any multivariate polynomial exceeds [`MAX_DEGREE`] in any variable. -/// - The round polynomials are inconsistent with their corresponding claimed sum on `0` and `1`. -// TODO: Consider returning constant oracles as separate type. -pub fn prove_batch( - mut claims: Vec, - mut multivariate_polys: Vec, - lambda: SecureField, - channel: &mut impl Channel, -) -> (SumcheckProof, Vec, Vec, Vec) { - let n_variables = multivariate_polys.iter().map(O::n_variables).max().unwrap(); - assert_eq!(claims.len(), multivariate_polys.len()); - - let mut round_polys = Vec::new(); - let mut assignment = Vec::new(); - - // Update the claims for the sum over `h`'s hypercube. - for (claim, multivariate_poly) in zip(&mut claims, &multivariate_polys) { - let n_unused_variables = n_variables - multivariate_poly.n_variables(); - *claim *= BaseField::from(1 << n_unused_variables); - } - - // Prove sum-check rounds - for round in 0..n_variables { - let n_remaining_rounds = n_variables - round; - - let this_round_polys = zip(&multivariate_polys, &claims) - .enumerate() - .map(|(i, (multivariate_poly, &claim))| { - let round_poly = if n_remaining_rounds == multivariate_poly.n_variables() { - multivariate_poly.sum_as_poly_in_first_variable(claim) - } else { - (claim / BaseField::from(2)).into() - }; - - let eval_at_0 = round_poly.eval_at_point(SecureField::zero()); - let eval_at_1 = round_poly.eval_at_point(SecureField::one()); - assert_eq!(eval_at_0 + eval_at_1, claim, "i={i}, round={round}"); - assert!(round_poly.degree() <= MAX_DEGREE, "i={i}, round={round}"); - - round_poly - }) - .collect_vec(); - - let round_poly = random_linear_combination(&this_round_polys, lambda); - - channel.mix_felts(&round_poly); - - let challenge = channel.draw_felt(); - - claims = this_round_polys - .iter() - .map(|round_poly| round_poly.eval_at_point(challenge)) - .collect(); - - multivariate_polys = multivariate_polys - .into_iter() - .map(|multivariate_poly| { - if n_remaining_rounds != multivariate_poly.n_variables() { - return multivariate_poly; - } - - multivariate_poly.fix_first_variable(challenge) - }) - .collect(); - - round_polys.push(round_poly); - assignment.push(challenge); - } - - let proof = SumcheckProof { round_polys }; - - (proof, assignment, multivariate_polys, claims) -} - -/// Returns `p_0 + alpha * p_1 + ... + alpha^(n-1) * p_{n-1}`. -fn random_linear_combination( - polys: &[UnivariatePoly], - alpha: SecureField, -) -> UnivariatePoly { - polys - .iter() - .rfold(Zero::zero(), |acc, poly| acc * alpha + poly.clone()) -} - -/// Partially verifies a sum-check proof. -/// -/// Only "partial" since it does not fully verify the prover's claimed evaluation on the variable -/// assignment but checks if the sum of the round polynomials evaluated on `0` and `1` matches the -/// claim for each round. If the proof passes these checks, the variable assignment and the prover's -/// claimed evaluation are returned for the caller to validate otherwise an [`Err`] is returned. -/// -/// Output is of the form `(variable_assignment, claimed_eval)`. -pub fn partially_verify( - mut claim: SecureField, - proof: &SumcheckProof, - channel: &mut impl Channel, -) -> Result<(Vec, SecureField), SumcheckError> { - let mut assignment = Vec::new(); - - for (round, round_poly) in proof.round_polys.iter().enumerate() { - if round_poly.degree() > MAX_DEGREE { - return Err(SumcheckError::DegreeInvalid { round }); - } - - // TODO: optimize this by sending one less coefficient, and computing it from the - // claim, instead of checking the claim. (Can also be done by quotienting). - let sum = round_poly.eval_at_point(Zero::zero()) + round_poly.eval_at_point(One::one()); - - if claim != sum { - return Err(SumcheckError::SumInvalid { claim, sum, round }); - } - - channel.mix_felts(round_poly); - let challenge = channel.draw_felt(); - claim = round_poly.eval_at_point(challenge); - assignment.push(challenge); - } - - Ok((assignment, claim)) -} - -#[derive(Debug, Clone)] -pub struct SumcheckProof { - pub round_polys: Vec>, -} - -/// Max degree of polynomials the verifier accepts in each round of the protocol. -pub const MAX_DEGREE: usize = 3; - -/// Sum-check protocol verification error. -#[derive(Error, Debug)] -pub enum SumcheckError { - #[error("degree of the polynomial in round {round} is too high")] - DegreeInvalid { round: RoundIndex }, - #[error("sum does not match the claim in round {round} (sum {sum}, claim {claim})")] - SumInvalid { - claim: SecureField, - sum: SecureField, - round: RoundIndex, - }, -} - -/// Sum-check round index where 0 corresponds to the first round. -pub type RoundIndex = usize; - -#[cfg(test)] -mod tests { - - use num_traits::One; - - use crate::core::backend::CpuBackend; - use crate::core::channel::{Blake2sChannel, Channel}; - use crate::core::fields::qm31::SecureField; - use crate::core::fields::Field; - use crate::core::lookups::mle::Mle; - use crate::core::lookups::sumcheck::{partially_verify, prove_batch}; - - #[test] - fn sumcheck_works() { - let values = test_channel().draw_felts(32); - let claim = values.iter().sum(); - let mle = Mle::::new(values); - let lambda = SecureField::one(); - let (proof, ..) = prove_batch(vec![claim], vec![mle.clone()], lambda, &mut test_channel()); - - let (assignment, eval) = partially_verify(claim, &proof, &mut test_channel()).unwrap(); - - assert_eq!(eval, mle.eval_at_point(&assignment)); - } - - #[test] - fn batch_sumcheck_works() { - let mut channel = test_channel(); - let values0 = channel.draw_felts(32); - let values1 = channel.draw_felts(32); - let claim0 = values0.iter().sum(); - let claim1 = values1.iter().sum(); - let mle0 = Mle::::new(values0.clone()); - let mle1 = Mle::::new(values1.clone()); - let lambda = channel.draw_felt(); - let claims = vec![claim0, claim1]; - let mles = vec![mle0.clone(), mle1.clone()]; - let (proof, ..) = prove_batch(claims, mles, lambda, &mut test_channel()); - - let claim = claim0 + lambda * claim1; - let (assignment, eval) = partially_verify(claim, &proof, &mut test_channel()).unwrap(); - - let eval0 = mle0.eval_at_point(&assignment); - let eval1 = mle1.eval_at_point(&assignment); - assert_eq!(eval, eval0 + lambda * eval1); - } - - #[test] - fn batch_sumcheck_with_different_n_variables() { - let mut channel = test_channel(); - let values0 = channel.draw_felts(64); - let values1 = channel.draw_felts(32); - let claim0 = values0.iter().sum(); - let claim1 = values1.iter().sum(); - let mle0 = Mle::::new(values0.clone()); - let mle1 = Mle::::new(values1.clone()); - let lambda = channel.draw_felt(); - let claims = vec![claim0, claim1]; - let mles = vec![mle0.clone(), mle1.clone()]; - let (proof, ..) = prove_batch(claims, mles, lambda, &mut test_channel()); - - let claim = claim0 + lambda * claim1.double(); - let (assignment, eval) = partially_verify(claim, &proof, &mut test_channel()).unwrap(); - - let eval0 = mle0.eval_at_point(&assignment); - let eval1 = mle1.eval_at_point(&assignment[1..]); - assert_eq!(eval, eval0 + lambda * eval1); - } - - #[test] - fn invalid_sumcheck_proof_fails() { - let values = test_channel().draw_felts(8); - let claim = values.iter().sum::(); - let lambda = SecureField::one(); - // Compromise the first value. - let mut invalid_values = values; - invalid_values[0] += SecureField::one(); - let invalid_claim = vec![invalid_values.iter().sum::()]; - let invalid_mle = vec![Mle::::new(invalid_values.clone())]; - let (invalid_proof, ..) = - prove_batch(invalid_claim, invalid_mle, lambda, &mut test_channel()); - - assert!(partially_verify(claim, &invalid_proof, &mut test_channel()).is_err()); - } - - fn test_channel() -> Blake2sChannel { - Blake2sChannel::default() - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/lookups/utils.rs b/Stwo_wrapper/crates/prover/src/core/lookups/utils.rs deleted file mode 100644 index 85ea4c3..0000000 --- a/Stwo_wrapper/crates/prover/src/core/lookups/utils.rs +++ /dev/null @@ -1,356 +0,0 @@ -use std::iter::{zip, Sum}; -use std::ops::{Add, Deref, Mul, Neg, Sub}; - -use num_traits::{One, Zero}; - -use crate::core::fields::qm31::SecureField; -use crate::core::fields::{ExtensionOf, Field}; - -/// Univariate polynomial stored as coefficients in the monomial basis. -#[derive(Debug, Clone)] -pub struct UnivariatePoly(Vec); - -impl UnivariatePoly { - pub fn new(coeffs: Vec) -> Self { - let mut polynomial = Self(coeffs); - polynomial.truncate_leading_zeros(); - polynomial - } - - pub fn eval_at_point(&self, x: F) -> F { - horner_eval(&self.0, x) - } - - // - pub fn interpolate_lagrange(xs: &[F], ys: &[F]) -> Self { - assert_eq!(xs.len(), ys.len()); - - let mut coeffs = Self::zero(); - - for (i, (&xi, &yi)) in zip(xs, ys).enumerate() { - let mut prod = yi; - - for (j, &xj) in xs.iter().enumerate() { - if i != j { - prod /= xi - xj; - } - } - - let mut term = Self::new(vec![prod]); - - for (j, &xj) in xs.iter().enumerate() { - if i != j { - term = term * (Self::x() - Self::new(vec![xj])); - } - } - - coeffs = coeffs + term; - } - - coeffs.truncate_leading_zeros(); - - coeffs - } - - pub fn degree(&self) -> usize { - let mut coeffs = self.0.iter().rev(); - let _ = (&mut coeffs).take_while(|v| v.is_zero()); - coeffs.len().saturating_sub(1) - } - - fn x() -> Self { - Self(vec![F::zero(), F::one()]) - } - - fn truncate_leading_zeros(&mut self) { - while self.0.last() == Some(&F::zero()) { - self.0.pop(); - } - } -} - -impl From for UnivariatePoly { - fn from(value: F) -> Self { - Self::new(vec![value]) - } -} - -impl Mul for UnivariatePoly { - type Output = Self; - - fn mul(mut self, rhs: F) -> Self { - self.0.iter_mut().for_each(|coeff| *coeff *= rhs); - self - } -} - -impl Mul for UnivariatePoly { - type Output = Self; - - fn mul(mut self, mut rhs: Self) -> Self { - if self.is_zero() || rhs.is_zero() { - return Self::zero(); - } - - self.truncate_leading_zeros(); - rhs.truncate_leading_zeros(); - - let mut res = vec![F::zero(); self.0.len() + rhs.0.len() - 1]; - - for (i, coeff_a) in self.0.into_iter().enumerate() { - for (j, &coeff_b) in rhs.0.iter().enumerate() { - res[i + j] += coeff_a * coeff_b; - } - } - - Self::new(res) - } -} - -impl Add for UnivariatePoly { - type Output = Self; - - fn add(self, rhs: Self) -> Self { - let n = self.0.len().max(rhs.0.len()); - let mut res = Vec::new(); - - for i in 0..n { - res.push(match (self.0.get(i), rhs.0.get(i)) { - (Some(&a), Some(&b)) => a + b, - (Some(&a), None) | (None, Some(&a)) => a, - _ => unreachable!(), - }) - } - - Self(res) - } -} - -impl Sub for UnivariatePoly { - type Output = Self; - - fn sub(self, rhs: Self) -> Self { - self + (-rhs) - } -} - -impl Neg for UnivariatePoly { - type Output = Self; - - fn neg(self) -> Self { - Self(self.0.into_iter().map(|v| -v).collect()) - } -} - -impl Zero for UnivariatePoly { - fn zero() -> Self { - Self(vec![]) - } - - fn is_zero(&self) -> bool { - self.0.iter().all(F::is_zero) - } -} - -impl Deref for UnivariatePoly { - type Target = [F]; - - fn deref(&self) -> &[F] { - &self.0 - } -} - -/// Evaluates univariate polynomial using [Horner's method]. -/// -/// [Horner's method]: https://en.wikipedia.org/wiki/Horner%27s_method -pub fn horner_eval(coeffs: &[F], x: F) -> F { - coeffs - .iter() - .rfold(F::zero(), |acc, &coeff| acc * x + coeff) -} - -/// Returns `v_0 + alpha * v_1 + ... + alpha^(n-1) * v_{n-1}`. -pub fn random_linear_combination(v: &[SecureField], alpha: SecureField) -> SecureField { - horner_eval(v, alpha) -} - -/// Evaluates the lagrange kernel of the boolean hypercube. -/// -/// The lagrange kernel of the boolean hypercube is a multilinear extension of the function that -/// when given `x, y` in `{0, 1}^n` evaluates to 1 if `x = y`, and evaluates to 0 otherwise. -pub fn eq(x: &[F], y: &[F]) -> F { - assert_eq!(x.len(), y.len()); - zip(x, y) - .map(|(&xi, &yi)| xi * yi + (F::one() - xi) * (F::one() - yi)) - .product() -} - -/// Computes `eq(0, assignment) * eval0 + eq(1, assignment) * eval1`. -pub fn fold_mle_evals(assignment: SecureField, eval0: F, eval1: F) -> SecureField -where - F: Field, - SecureField: ExtensionOf, -{ - assignment * (eval1 - eval0) + eval0 -} - -/// Projective fraction. -#[derive(Debug, Clone, Copy)] -pub struct Fraction { - pub numerator: N, - pub denominator: D, -} - -impl Fraction { - pub fn new(numerator: N, denominator: D) -> Self { - Self { - numerator, - denominator, - } - } -} - -impl + Add + Mul + Mul + Copy> Add - for Fraction -{ - type Output = Fraction; - - fn add(self, rhs: Self) -> Fraction { - Fraction { - numerator: rhs.denominator * self.numerator + self.denominator * rhs.numerator, - denominator: self.denominator * rhs.denominator, - } - } -} - -impl Zero for Fraction -where - Self: Add, -{ - fn zero() -> Self { - Self { - numerator: N::zero(), - denominator: D::one(), - } - } - - fn is_zero(&self) -> bool { - self.numerator.is_zero() && !self.denominator.is_zero() - } -} - -impl Sum for Fraction -where - Self: Zero, -{ - fn sum>(mut iter: I) -> Self { - let first = iter.next().unwrap_or_else(Self::zero); - iter.fold(first, |a, b| a + b) - } -} - -/// Represents the fraction `1 / x` -pub struct Reciprocal { - x: T, -} - -impl Reciprocal { - pub fn new(x: T) -> Self { - Self { x } - } -} - -impl + Mul + Copy> Add for Reciprocal { - type Output = Fraction; - - fn add(self, rhs: Self) -> Fraction { - // `1/a + 1/b = (a + b)/(a * b)` - Fraction { - numerator: self.x + rhs.x, - denominator: self.x * rhs.x, - } - } -} - -#[cfg(test)] -mod tests { - use std::iter::zip; - - use num_traits::{One, Zero}; - - use super::{horner_eval, UnivariatePoly}; - use crate::core::fields::m31::BaseField; - use crate::core::fields::qm31::SecureField; - use crate::core::fields::FieldExpOps; - use crate::core::lookups::utils::{eq, Fraction}; - - #[test] - fn lagrange_interpolation_works() { - let xs = [5, 1, 3, 9].map(BaseField::from); - let ys = [1, 2, 3, 4].map(BaseField::from); - - let poly = UnivariatePoly::interpolate_lagrange(&xs, &ys); - - for (x, y) in zip(xs, ys) { - assert_eq!(poly.eval_at_point(x), y, "mismatch for x={x}"); - } - } - - #[test] - fn horner_eval_works() { - let coeffs = [BaseField::from(9), BaseField::from(2), BaseField::from(3)]; - let x = BaseField::from(7); - - let eval = horner_eval(&coeffs, x); - - assert_eq!(eval, coeffs[0] + coeffs[1] * x + coeffs[2] * x.square()); - } - - #[test] - fn eq_identical_hypercube_points_returns_one() { - let zero = SecureField::zero(); - let one = SecureField::one(); - let a = &[one, zero, one]; - - let eq_eval = eq(a, a); - - assert_eq!(eq_eval, one); - } - - #[test] - fn eq_different_hypercube_points_returns_zero() { - let zero = SecureField::zero(); - let one = SecureField::one(); - let a = &[one, zero, one]; - let b = &[one, zero, zero]; - - let eq_eval = eq(a, b); - - assert_eq!(eq_eval, zero); - } - - #[test] - #[should_panic] - fn eq_different_size_points() { - let zero = SecureField::zero(); - let one = SecureField::one(); - - eq(&[zero, one], &[zero]); - } - - #[test] - fn fraction_addition_works() { - let a = Fraction::new(BaseField::from(1), BaseField::from(3)); - let b = Fraction::new(BaseField::from(2), BaseField::from(6)); - - let Fraction { - numerator, - denominator, - } = a + b; - - assert_eq!( - numerator / denominator, - BaseField::from(2) / BaseField::from(3) - ); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/mod.rs b/Stwo_wrapper/crates/prover/src/core/mod.rs deleted file mode 100644 index a00aad6..0000000 --- a/Stwo_wrapper/crates/prover/src/core/mod.rs +++ /dev/null @@ -1,59 +0,0 @@ -use std::ops::{Deref, DerefMut}; - -pub mod air; -pub mod backend; -pub mod channel; -pub mod circle; -pub mod constraints; -pub mod fft; -pub mod fields; -pub mod fri; -pub mod lookups; -pub mod pcs; -pub mod poly; -pub mod proof_of_work; -pub mod prover; -pub mod queries; -#[cfg(test)] -pub mod test_utils; -pub mod utils; -pub mod vcs; - -/// A vector in which each element relates (by index) to a column in the trace. -pub type ColumnVec = Vec; - -/// A vector of [ColumnVec]s. Each [ColumnVec] relates (by index) to a component in the air. -#[derive(Debug, Clone)] -pub struct ComponentVec(pub Vec>); - -impl ComponentVec { - pub fn flatten(self) -> ColumnVec { - self.0.into_iter().flatten().collect() - } -} - -impl ComponentVec> { - pub fn flatten_cols(self) -> Vec { - self.0.into_iter().flatten().flatten().collect() - } -} - -impl Default for ComponentVec { - fn default() -> Self { - Self(Vec::new()) - } -} - -impl Deref for ComponentVec { - type Target = Vec>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for ComponentVec { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/pcs/mod.rs b/Stwo_wrapper/crates/prover/src/core/pcs/mod.rs deleted file mode 100644 index d9acf52..0000000 --- a/Stwo_wrapper/crates/prover/src/core/pcs/mod.rs +++ /dev/null @@ -1,40 +0,0 @@ -//! Implements a FRI polynomial commitment scheme. -//! This is a protocol where the prover can commit on a set of polynomials and then prove their -//! opening on a set of points. -//! Note: This implementation is not really a polynomial commitment scheme, because we are not in -//! the unique decoding regime. This is enough for a STARK proof though, where we only want to imply -//! the existence of such polynomials, and are ok with having a small decoding list. -//! Note: Opened points cannot come from the commitment domain. - -mod prover; -pub mod quotients; -mod utils; -mod verifier; - -pub use self::prover::{ - CommitmentSchemeProof, CommitmentSchemeProver, CommitmentTreeProver, TreeBuilder, -}; -pub use self::utils::TreeVec; -pub use self::verifier::CommitmentSchemeVerifier; -use super::fri::FriConfig; - -#[derive(Copy, Debug, Clone, PartialEq, Eq)] -pub struct TreeSubspan { - pub tree_index: usize, - pub col_start: usize, - pub col_end: usize, -} - -#[derive(Debug, Clone, Copy)] -pub struct PcsConfig { - pub pow_bits: u32, - pub fri_config: FriConfig, -} -impl Default for PcsConfig { - fn default() -> Self { - Self { - pow_bits: 5, - fri_config: FriConfig::new(0, 1, 3), - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/pcs/prover.rs b/Stwo_wrapper/crates/prover/src/core/pcs/prover.rs deleted file mode 100644 index ed45ffc..0000000 --- a/Stwo_wrapper/crates/prover/src/core/pcs/prover.rs +++ /dev/null @@ -1,256 +0,0 @@ -use std::collections::BTreeMap; - -use itertools::Itertools; -use tracing::{span, Level}; - -use super::super::circle::CirclePoint; -use super::super::fields::m31::BaseField; -use super::super::fields::qm31::SecureField; -use super::super::fri::{FriProof, FriProver}; -use super::super::poly::circle::CanonicCoset; -use super::super::poly::BitReversedOrder; -use super::super::ColumnVec; -use super::quotients::{compute_fri_quotients, PointSample}; -use super::utils::TreeVec; -use super::{PcsConfig, TreeSubspan}; -use crate::core::air::Trace; -use crate::core::backend::BackendForChannel; -use crate::core::channel::{Channel, MerkleChannel}; -use crate::core::poly::circle::{CircleEvaluation, CirclePoly}; -use crate::core::poly::twiddles::TwiddleTree; -use crate::core::vcs::ops::MerkleHasher; -use crate::core::vcs::prover::{MerkleDecommitment, MerkleProver}; - -/// The prover side of a FRI polynomial commitment scheme. See [super]. -pub struct CommitmentSchemeProver<'a, B: BackendForChannel, MC: MerkleChannel> { - pub trees: TreeVec>, - pub config: PcsConfig, - twiddles: &'a TwiddleTree, -} - -impl<'a, B: BackendForChannel, MC: MerkleChannel> CommitmentSchemeProver<'a, B, MC> { - pub fn new(config: PcsConfig, twiddles: &'a TwiddleTree) -> Self { - CommitmentSchemeProver { - trees: TreeVec::default(), - config, - twiddles, - } - } - - fn commit(&mut self, polynomials: ColumnVec>, channel: &mut MC::C) { - let _span = span!(Level::INFO, "Commitment").entered(); - let tree = CommitmentTreeProver::new( - polynomials, - self.config.fri_config.log_blowup_factor, - channel, - self.twiddles, - ); - self.trees.push(tree); - } - - pub fn tree_builder(&mut self) -> TreeBuilder<'_, 'a, B, MC> { - TreeBuilder { - tree_index: self.trees.len(), - commitment_scheme: self, - polys: Vec::default(), - } - } - - pub fn roots(&self) -> TreeVec<::Hash> { - self.trees.as_ref().map(|tree| tree.commitment.root()) - } - - pub fn polynomials(&self) -> TreeVec>> { - self.trees - .as_ref() - .map(|tree| tree.polynomials.iter().collect()) - } - - pub fn evaluations( - &self, - ) -> TreeVec>> { - self.trees - .as_ref() - .map(|tree| tree.evaluations.iter().collect()) - } - - pub fn trace(&self) -> Trace<'_, B> { - let polys = self.polynomials(); - let evals = self.evaluations(); - Trace { polys, evals } - } - - pub fn prove_values( - &self, - sampled_points: TreeVec>>>, - channel: &mut MC::C, - ) -> CommitmentSchemeProof { - // Evaluate polynomials on open points. - let span = span!(Level::INFO, "Evaluate columns out of domain").entered(); - let samples = self - .polynomials() - .zip_cols(&sampled_points) - .map_cols(|(poly, points)| { - points - .iter() - .map(|&point| PointSample { - point, - value: poly.eval_at_point(point), - }) - .collect_vec() - }); - span.exit(); - let sampled_values = samples - .as_cols_ref() - .map_cols(|x| x.iter().map(|o| o.value).collect()); - channel.mix_felts(&sampled_values.clone().flatten_cols()); - - // Compute oods quotients for boundary constraints on the sampled points. - let columns = self.evaluations().flatten(); - let quotients = compute_fri_quotients( - &columns, - &samples.flatten(), - channel.draw_felt(), - self.config.fri_config.log_blowup_factor, - ); - - // Run FRI commitment phase on the oods quotients. - let fri_prover = - FriProver::::commit(channel, self.config.fri_config, "ients, self.twiddles); - - // Proof of work. - let span1 = span!(Level::INFO, "Grind").entered(); - let proof_of_work = B::grind(channel, self.config.pow_bits); - span1.exit(); - channel.mix_nonce(proof_of_work); - - // FRI decommitment phase. - let (fri_proof, fri_query_domains) = fri_prover.decommit(channel); - - // Decommit the FRI queries on the merkle trees. - let decommitment_results = self.trees.as_ref().map(|tree| { - let queries = fri_query_domains - .iter() - .map(|(&log_size, domain)| (log_size, domain.flatten())) - .collect(); - tree.decommit(queries) - }); - - let queried_values = decommitment_results.as_ref().map(|(v, _)| v.clone()); - let decommitments = decommitment_results.map(|(_, d)| d); - - CommitmentSchemeProof { - sampled_values, - decommitments, - queried_values, - proof_of_work, - fri_proof, - } - } -} - -#[derive(Debug)] -pub struct CommitmentSchemeProof { - pub sampled_values: TreeVec>>, - pub decommitments: TreeVec>, - pub queried_values: TreeVec>>, - pub proof_of_work: u64, - pub fri_proof: FriProof, -} - -pub struct TreeBuilder<'a, 'b, B: BackendForChannel, MC: MerkleChannel> { - tree_index: usize, - commitment_scheme: &'a mut CommitmentSchemeProver<'b, B, MC>, - polys: ColumnVec>, -} -impl<'a, 'b, B: BackendForChannel, MC: MerkleChannel> TreeBuilder<'a, 'b, B, MC> { - pub fn extend_evals( - &mut self, - columns: ColumnVec>, - ) -> TreeSubspan { - let span = span!(Level::INFO, "Interpolation for commitment").entered(); - let col_start = self.polys.len(); - let polys = columns - .into_iter() - .map(|eval| eval.interpolate_with_twiddles(self.commitment_scheme.twiddles)) - .collect_vec(); - span.exit(); - self.polys.extend(polys); - TreeSubspan { - tree_index: self.tree_index, - col_start, - col_end: self.polys.len(), - } - } - - pub fn extend_polys(&mut self, polys: ColumnVec>) -> TreeSubspan { - let col_start = self.polys.len(); - self.polys.extend(polys); - TreeSubspan { - tree_index: self.tree_index, - col_start, - col_end: self.polys.len(), - } - } - - pub fn commit(self, channel: &mut MC::C) { - let _span = span!(Level::INFO, "Commitment").entered(); - self.commitment_scheme.commit(self.polys, channel); - } -} - -/// Prover data for a single commitment tree in a commitment scheme. The commitment scheme allows to -/// commit on a set of polynomials at a time. This corresponds to such a set. -pub struct CommitmentTreeProver, MC: MerkleChannel> { - pub polynomials: ColumnVec>, - pub evaluations: ColumnVec>, - pub commitment: MerkleProver, -} - -impl, MC: MerkleChannel> CommitmentTreeProver { - pub fn new( - polynomials: ColumnVec>, - log_blowup_factor: u32, - channel: &mut MC::C, - twiddles: &TwiddleTree, - ) -> Self { - let span = span!(Level::INFO, "Extension").entered(); - let evaluations = polynomials - .iter() - .map(|poly| { - poly.evaluate_with_twiddles( - CanonicCoset::new(poly.log_size() + log_blowup_factor).circle_domain(), - twiddles, - ) - }) - .collect_vec(); - - span.exit(); - - let _span = span!(Level::INFO, "Merkle").entered(); - let tree = MerkleProver::commit(evaluations.iter().map(|eval| &eval.values).collect()); - MC::mix_root(channel, tree.root()); - - CommitmentTreeProver { - polynomials, - evaluations, - commitment: tree, - } - } - - /// Decommits the merkle tree on the given query positions. - /// Returns the values at the queried positions and the decommitment. - /// The queries are given as a mapping from the log size of the layer size to the queried - /// positions on each column of that size. - fn decommit( - &self, - queries: BTreeMap>, - ) -> (ColumnVec>, MerkleDecommitment) { - let eval_vec = self - .evaluations - .iter() - .map(|eval| &eval.values) - .collect_vec(); - self.commitment.decommit(queries, eval_vec) - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/pcs/quotients.rs b/Stwo_wrapper/crates/prover/src/core/pcs/quotients.rs deleted file mode 100644 index 1034e05..0000000 --- a/Stwo_wrapper/crates/prover/src/core/pcs/quotients.rs +++ /dev/null @@ -1,218 +0,0 @@ -use std::cmp::Reverse; -use std::collections::BTreeMap; -use std::iter::zip; - -use itertools::{izip, multiunzip, Itertools}; -use tracing::{span, Level}; - -use crate::core::backend::cpu::quotients::{accumulate_row_quotients, quotient_constants}; -use crate::core::circle::CirclePoint; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fri::SparseCircleEvaluation; -use crate::core::poly::circle::{ - CanonicCoset, CircleDomain, CircleEvaluation, PolyOps, SecureEvaluation, -}; -use crate::core::poly::BitReversedOrder; -use crate::core::prover::VerificationError; -use crate::core::queries::SparseSubCircleDomain; -use crate::core::utils::bit_reverse_index; - -pub trait QuotientOps: PolyOps { - /// Accumulates the quotients of the columns at the given domain. - /// For a column f(x), and a point sample (p,v), the quotient is - /// (f(x) - V0(x))/V1(x) - /// where V0(p)=v, V0(conj(p))=conj(v), and V1 is a vanishing polynomial for p,conj(p). - /// This ensures that if f(p)=v, then the quotient is a polynomial. - /// The result is a linear combination of the quotients using powers of random_coeff. - fn accumulate_quotients( - domain: CircleDomain, - columns: &[&CircleEvaluation], - random_coeff: SecureField, - sample_batches: &[ColumnSampleBatch], - log_blowup_factor: u32, - ) -> SecureEvaluation; -} - -/// A batch of column samplings at a point. -pub struct ColumnSampleBatch { - /// The point at which the columns are sampled. - pub point: CirclePoint, - /// The sampled column indices and their values at the point. - pub columns_and_values: Vec<(usize, SecureField)>, -} - -impl ColumnSampleBatch { - /// Groups column samples by sampled point. - /// # Arguments - /// samples: For each column, a vector of samples. - pub fn new_vec(samples: &[&Vec]) -> Vec { - // Group samples by point, and create a ColumnSampleBatch for each point. - // This should keep a stable ordering. - let mut grouped_samples = BTreeMap::new(); - for (column_index, samples) in samples.iter().enumerate() { - for sample in samples.iter() { - grouped_samples - .entry(sample.point) - .or_insert_with(Vec::new) - .push((column_index, sample.value)); - } - } - grouped_samples - .into_iter() - .map(|(point, columns_and_values)| ColumnSampleBatch { - point, - columns_and_values, - }) - .collect() - } -} - -pub struct PointSample { - pub point: CirclePoint, - pub value: SecureField, -} - -pub fn compute_fri_quotients( - columns: &[&CircleEvaluation], - samples: &[Vec], - random_coeff: SecureField, - log_blowup_factor: u32, -) -> Vec> { - let _span = span!(Level::INFO, "Compute FRI quotients").entered(); - zip(columns, samples) - .sorted_by_key(|(c, _)| Reverse(c.domain.log_size())) - .group_by(|(c, _)| c.domain.log_size()) - .into_iter() - .map(|(log_size, tuples)| { - let (columns, samples): (Vec<_>, Vec<_>) = tuples.unzip(); - let domain = CanonicCoset::new(log_size).circle_domain(); - // TODO: slice. - let sample_batches = ColumnSampleBatch::new_vec(&samples); - B::accumulate_quotients( - domain, - &columns, - random_coeff, - &sample_batches, - log_blowup_factor, - ) - }) - .collect() -} - -pub fn fri_answers( - column_log_sizes: Vec, - samples: &[Vec], - random_coeff: SecureField, - query_domain_per_log_size: BTreeMap, - queried_values_per_column: &[Vec], -) -> Result, VerificationError> { - izip!(column_log_sizes, samples, queried_values_per_column) - .sorted_by_key(|(log_size, ..)| Reverse(*log_size)) - .group_by(|(log_size, ..)| *log_size) - .into_iter() - .map(|(log_size, tuples)| { - let (_, samples, queried_valued_per_column): (Vec<_>, Vec<_>, Vec<_>) = - multiunzip(tuples); - fri_answers_for_log_size( - log_size, - &samples, // Here it's the points and sampled values (in the proof) - random_coeff, - &query_domain_per_log_size[&log_size], - &queried_valued_per_column, // Here it's queried values (in the proof) - ) - }) - .collect() -} - -pub fn fri_answers_for_log_size( - log_size: u32, - samples: &[&Vec], - random_coeff: SecureField, - query_domain: &SparseSubCircleDomain, - queried_values_per_column: &[&Vec], -) -> Result { - let commitment_domain = CanonicCoset::new(log_size).circle_domain(); - let sample_batches = ColumnSampleBatch::new_vec(samples); - for queried_values in queried_values_per_column { - if queried_values.len() != query_domain.flatten().len() { - return Err(VerificationError::InvalidStructure( - "Insufficient number of queried values".to_string(), - )); - } - } - let mut queried_values_per_column = queried_values_per_column - .iter() - .map(|q| q.iter()) - .collect_vec(); - - let mut evals = Vec::new(); - for subdomain in query_domain.iter() { - let domain = subdomain.to_circle_domain(&commitment_domain); - let quotient_constants = quotient_constants(&sample_batches, random_coeff, domain); - let mut column_evals = Vec::new(); - for queried_values in queried_values_per_column.iter_mut() { - let eval = CircleEvaluation::new( - domain, - queried_values.take(domain.size()).copied().collect_vec(), - ); - column_evals.push(eval); - } - - let mut values = Vec::new(); - for row in 0..domain.size() { - let domain_point = domain.at(bit_reverse_index(row, log_size)); - let value = accumulate_row_quotients( - &sample_batches, - &column_evals.iter().collect_vec(), - "ient_constants, - row, - domain_point, - ); - values.push(value); - } - let eval = CircleEvaluation::new(domain, values); - evals.push(eval); - } - - let res = SparseCircleEvaluation::new(evals); - if !queried_values_per_column.iter().all(|x| x.is_empty()) { - return Err(VerificationError::InvalidStructure( - "Too many queried values".to_string(), - )); - } - Ok(res) -} - -#[cfg(test)] -mod tests { - use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly}; - use crate::core::circle::SECURE_FIELD_CIRCLE_GEN; - use crate::core::pcs::quotients::{compute_fri_quotients, PointSample}; - use crate::core::poly::circle::CanonicCoset; - use crate::{m31, qm31}; - - #[test] - fn test_quotients_are_low_degree() { - const LOG_SIZE: u32 = 7; - const LOG_BLOWUP_FACTOR: u32 = 1; - let polynomial = CpuCirclePoly::new((0..1 << LOG_SIZE).map(|i| m31!(i)).collect()); - let eval_domain = CanonicCoset::new(LOG_SIZE + 1).circle_domain(); - let eval = polynomial.evaluate(eval_domain); - let point = SECURE_FIELD_CIRCLE_GEN; - let value = polynomial.eval_at_point(point); - let coeff = qm31!(1, 2, 3, 4); - let quot_eval = compute_fri_quotients( - &[&eval], - &[vec![PointSample { point, value }]], - coeff, - LOG_BLOWUP_FACTOR, - ) - .pop() - .unwrap(); - let quot_poly_base_field = - CpuCircleEvaluation::new(eval_domain, quot_eval.values.columns[0].clone()) - .interpolate(); - assert!(quot_poly_base_field.is_in_fri_space(LOG_SIZE)); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/pcs/utils.rs b/Stwo_wrapper/crates/prover/src/core/pcs/utils.rs deleted file mode 100644 index bfdbdb5..0000000 --- a/Stwo_wrapper/crates/prover/src/core/pcs/utils.rs +++ /dev/null @@ -1,158 +0,0 @@ -use std::collections::BTreeSet; -use std::ops::{Deref, DerefMut}; - -use itertools::zip_eq; -use serde::{Deserialize, Serialize}; - -use super::TreeSubspan; -use crate::core::ColumnVec; - -/// A container that holds an element for each commitment tree. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TreeVec(pub Vec); - -impl TreeVec { - pub fn new(vec: Vec) -> TreeVec { - TreeVec(vec) - } - pub fn map U>(self, f: F) -> TreeVec { - TreeVec(self.0.into_iter().map(f).collect()) - } - pub fn zip(self, other: impl Into>) -> TreeVec<(T, U)> { - let other = other.into(); - TreeVec(self.0.into_iter().zip(other.0).collect()) - } - pub fn zip_eq(self, other: impl Into>) -> TreeVec<(T, U)> { - let other = other.into(); - TreeVec(zip_eq(self.0, other.0).collect()) - } - pub fn as_ref(&self) -> TreeVec<&T> { - TreeVec(self.iter().collect()) - } - pub fn as_mut(&mut self) -> TreeVec<&mut T> { - TreeVec(self.iter_mut().collect()) - } -} - -/// Converts `&TreeVec` to `TreeVec<&T>`. -impl<'a, T> From<&'a TreeVec> for TreeVec<&'a T> { - fn from(val: &'a TreeVec) -> Self { - val.as_ref() - } -} - -impl Deref for TreeVec { - type Target = Vec; - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for TreeVec { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl Default for TreeVec { - fn default() -> Self { - TreeVec(Vec::new()) - } -} - -impl TreeVec> { - pub fn map_cols U>(self, mut f: F) -> TreeVec> { - TreeVec( - self.0 - .into_iter() - .map(|column| column.into_iter().map(&mut f).collect()) - .collect(), - ) - } - - /// Zips two [`TreeVec>`] with the same structure (number of columns in each tree). - /// The resulting [`TreeVec>`] has the same structure, with each value being a tuple - /// of the corresponding values from the input [`TreeVec>`]. - pub fn zip_cols( - self, - other: impl Into>>, - ) -> TreeVec> { - let other = other.into(); - TreeVec( - zip_eq(self.0, other.0) - .map(|(column1, column2)| zip_eq(column1, column2).collect()) - .collect(), - ) - } - - pub fn as_cols_ref(&self) -> TreeVec> { - TreeVec(self.iter().map(|column| column.iter().collect()).collect()) - } - - /// Flattens the [`TreeVec>`] into a single [`ColumnVec`] with all the columns - /// combined. - pub fn flatten(self) -> ColumnVec { - self.0.into_iter().flatten().collect() - } - - /// Appends the columns of another [`TreeVec>`] to this one. - pub fn append_cols(&mut self, mut other: TreeVec>) { - let n_trees = self.0.len().max(other.0.len()); - self.0.resize_with(n_trees, Default::default); - for (self_col, other_col) in self.0.iter_mut().zip(other.0.iter_mut()) { - self_col.append(other_col); - } - } - - /// Concatenates the columns of multiple [`TreeVec>`] into a single - /// [`TreeVec>`]. - pub fn concat_cols( - trees: impl Iterator>>, - ) -> TreeVec> { - let mut result = TreeVec::default(); - for tree in trees { - result.append_cols(tree); - } - result - } - - /// Extracts a sub-tree based on the specified locations. - /// - /// # Panics - /// - /// If two or more locations have the same tree index. - pub fn sub_tree(&self, locations: &[TreeSubspan]) -> TreeVec> { - let tree_indicies: BTreeSet = locations.iter().map(|l| l.tree_index).collect(); - assert_eq!(tree_indicies.len(), locations.len()); - let max_tree_index = tree_indicies.iter().max().unwrap_or(&0); - let mut res = TreeVec(vec![Vec::new(); max_tree_index + 1]); - - for &location in locations { - // TODO(andrew): Throwing error here might be better instead. - let chunk = self.get_chunk(location).unwrap(); - res[location.tree_index] = chunk; - } - - res - } - - fn get_chunk(&self, location: TreeSubspan) -> Option> { - let tree = self.0.get(location.tree_index)?; - let chunk = tree.get(location.col_start..location.col_end)?; - Some(chunk.iter().collect()) - } -} - -impl<'a, T> From<&'a TreeVec>> for TreeVec> { - fn from(val: &'a TreeVec>) -> Self { - val.as_cols_ref() - } -} - -impl TreeVec>> { - /// Flattens a [`TreeVec>`] of [Vec]s into a single [Vec] with all the elements - /// combined. - pub fn flatten_cols(self) -> Vec { - self.0.into_iter().flatten().flatten().collect() - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/pcs/verifier.rs b/Stwo_wrapper/crates/prover/src/core/pcs/verifier.rs deleted file mode 100644 index da58a84..0000000 --- a/Stwo_wrapper/crates/prover/src/core/pcs/verifier.rs +++ /dev/null @@ -1,134 +0,0 @@ -use std::iter::zip; - -use itertools::Itertools; - -use super::super::circle::CirclePoint; -use super::super::fields::qm31::SecureField; -use super::super::fri::{CirclePolyDegreeBound, FriVerifier}; -use super::quotients::{fri_answers, PointSample}; -use super::utils::TreeVec; -use super::{CommitmentSchemeProof, PcsConfig}; -use crate::core::channel::{Channel, MerkleChannel}; -use crate::core::prover::VerificationError; -use crate::core::vcs::ops::MerkleHasher; -use crate::core::vcs::verifier::MerkleVerifier; -use crate::core::ColumnVec; - -/// The verifier side of a FRI polynomial commitment scheme. See [super]. -#[derive(Default)] -pub struct CommitmentSchemeVerifier { - pub trees: TreeVec>, - pub config: PcsConfig, -} - -impl CommitmentSchemeVerifier { - pub fn new(config: PcsConfig) -> Self { - Self { - trees: TreeVec::default(), - config, - } - } - - /// A [TreeVec] of the log sizes of each column in each commitment tree. - fn column_log_sizes(&self) -> TreeVec> { - self.trees - .as_ref() - .map(|tree| tree.column_log_sizes.clone()) - } - - /// Reads a commitment from the prover. - pub fn commit( - &mut self, - commitment: ::Hash, - log_sizes: &[u32], - channel: &mut MC::C, - ) { - MC::mix_root(channel, commitment); - let extended_log_sizes = log_sizes - .iter() - .map(|&log_size| log_size + self.config.fri_config.log_blowup_factor) - .collect(); - let verifier = MerkleVerifier::new(commitment, extended_log_sizes); - self.trees.push(verifier); - } - - pub fn verify_values( - &self, - sampled_points: TreeVec>>>, - proof: CommitmentSchemeProof, - channel: &mut MC::C, - ) -> Result<(), VerificationError> { - - channel.mix_felts(&proof.sampled_values.clone().flatten_cols()); - let random_coeff = channel.draw_felt(); - - let bounds = self - .column_log_sizes() - .zip_cols(&sampled_points) - .map_cols(|(log_size, sampled_points)| { - vec![ - CirclePolyDegreeBound::new(log_size - self.config.fri_config.log_blowup_factor); - sampled_points.len() - ] - }) - .flatten_cols() - .into_iter() - .sorted() - .rev() - .dedup() - .collect_vec(); - - // FRI commitment phase on OODS quotients. - let mut fri_verifier = - FriVerifier::::commit(channel, self.config.fri_config, proof.fri_proof, bounds)?; - - // Verify proof of work. - channel.mix_nonce(proof.proof_of_work); - if channel.trailing_zeros() < self.config.pow_bits { - return Err(VerificationError::ProofOfWork); - } - - // Get FRI query domains. - let fri_query_domains = fri_verifier.column_query_positions(channel); - - // Verify merkle decommitments. - self.trees - .as_ref() - .zip_eq(proof.decommitments) - .zip_eq(proof.queried_values.clone()) - .map(|((tree, decommitment), queried_values)| { - let queries = fri_query_domains - .iter() - .map(|(&log_size, domain)| (log_size, domain.flatten())) - .collect(); - tree.verify(queries, queried_values, decommitment) - }) - .0 - .into_iter() - .collect::>()?; - println!("DONE"); - - // Answer FRI queries. - let samples = sampled_points // Sample point is always the same and the value is the one provided in the proof - .zip_cols(proof.sampled_values) - .map_cols(|(sampled_points, sampled_values)| { - zip(sampled_points, sampled_values) - .map(|(point, value)| PointSample { point, value }) - .collect_vec() - }) - .flatten(); - - // TODO(spapini): Properly defined column log size and dinstinguish between poly and - // commitment. - let fri_answers = fri_answers( - self.column_log_sizes().flatten().into_iter().collect(), - &samples, - random_coeff, - fri_query_domains, - &proof.queried_values.flatten(), - )?; - - fri_verifier.decommit(fri_answers)?; - Ok(()) - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/circle/canonic.rs b/Stwo_wrapper/crates/prover/src/core/poly/circle/canonic.rs deleted file mode 100644 index 837e648..0000000 --- a/Stwo_wrapper/crates/prover/src/core/poly/circle/canonic.rs +++ /dev/null @@ -1,77 +0,0 @@ -use super::CircleDomain; -use crate::core::circle::{CirclePoint, CirclePointIndex, Coset}; -use crate::core::fields::m31::BaseField; - -/// A coset of the form G_{2n} + , where G_n is the generator of the -/// subgroup of order n. The ordering on this coset is G_2n + i * G_n. -/// These cosets can be used as a [CircleDomain], and be interpolated on. -/// Note that this changes the ordering on the coset to be like [CircleDomain], -/// which is G_2n + i * G_n/2 and then -G_2n -i * G_n/2. -/// For example, the Xs below are a canonic coset with n=8. -/// ```text -/// X O X -/// O O -/// X X -/// O O -/// X X -/// O O -/// X O X -/// ``` -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct CanonicCoset { - pub coset: Coset, -} - -impl CanonicCoset { - pub fn new(log_size: u32) -> Self { - assert!(log_size > 0); - Self { - coset: Coset::odds(log_size), - } - } - - /// Gets the full coset represented G_{2n} + . - pub fn coset(&self) -> Coset { - self.coset - } - - /// Gets half of the coset (its conjugate complements to the whole coset), G_{2n} + - pub fn half_coset(&self) -> Coset { - Coset::half_odds(self.log_size() - 1) - } - - /// Gets the [CircleDomain] representing the same point set (in another order). - pub fn circle_domain(&self) -> CircleDomain { - CircleDomain::new(self.half_coset()) - } - - /// Returns the log size of the coset. - pub fn log_size(&self) -> u32 { - self.coset.log_size - } - - /// Returns the size of the coset. - pub fn size(&self) -> usize { - self.coset.size() - } - - pub fn initial_index(&self) -> CirclePointIndex { - self.coset.initial_index - } - - pub fn step_size(&self) -> CirclePointIndex { - self.coset.step_size - } - - pub fn step(&self) -> CirclePoint { - self.coset.step - } - - pub fn index_at(&self, index: usize) -> CirclePointIndex { - self.coset.index_at(index) - } - - pub fn at(&self, i: usize) -> CirclePoint { - self.coset.at(i) - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/circle/domain.rs b/Stwo_wrapper/crates/prover/src/core/poly/circle/domain.rs deleted file mode 100644 index fba2bc3..0000000 --- a/Stwo_wrapper/crates/prover/src/core/poly/circle/domain.rs +++ /dev/null @@ -1,188 +0,0 @@ -use std::iter::Chain; - -use itertools::Itertools; - -use crate::core::circle::{ - CirclePoint, CirclePointIndex, Coset, CosetIterator, M31_CIRCLE_LOG_ORDER, -}; -use crate::core::fields::m31::BaseField; - -pub const MAX_CIRCLE_DOMAIN_LOG_SIZE: u32 = M31_CIRCLE_LOG_ORDER - 1; - -/// A valid domain for circle polynomial interpolation and evaluation. -/// Valid domains are a disjoint union of two conjugate cosets: +-C + . -/// The ordering defined on this domain is C + iG_n, and then -C - iG_n. -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct CircleDomain { - pub half_coset: Coset, -} - -impl CircleDomain { - /// Given a coset C + , constructs the circle domain +-C + (i.e., - /// this coset and its conjugate). - pub fn new(half_coset: Coset) -> Self { - Self { half_coset } - } - - pub fn iter(&self) -> CircleDomainIterator { - self.half_coset - .iter() - .chain(self.half_coset.conjugate().iter()) - } - - /// Iterates over point indices. - pub fn iter_indices(&self) -> CircleDomainIndexIterator { - self.half_coset - .iter_indices() - .chain(self.half_coset.conjugate().iter_indices()) - } - - /// Returns the size of the domain. - pub fn size(&self) -> usize { - 1 << self.log_size() - } - - /// Returns the log size of the domain. - pub fn log_size(&self) -> u32 { - self.half_coset.log_size + 1 - } - - /// Returns the `i` th domain element. - pub fn at(&self, i: usize) -> CirclePoint { - self.index_at(i).to_point() - } - - /// Returns the [CirclePointIndex] of the `i`th domain element. - pub fn index_at(&self, i: usize) -> CirclePointIndex { - if i < self.half_coset.size() { - self.half_coset.index_at(i) - } else { - -self.half_coset.index_at(i - self.half_coset.size()) - } - } - - pub fn find(&self, i: CirclePointIndex) -> Option { - if let Some(d) = self.half_coset.find(i) { - return Some(d); - } - if let Some(d) = self.half_coset.conjugate().find(i) { - return Some(self.half_coset.size() + d); - } - None - } - - /// Returns true if the domain is canonic. - /// - /// Canonic domains are domains with elements that are the entire set of points defined by - /// `G_2n + ` where `G_n` and `G_2n` are obtained by repeatedly doubling - /// [crate::core::circle::M31_CIRCLE_GEN]. - pub fn is_canonic(&self) -> bool { - self.half_coset.initial_index * 4 == self.half_coset.step_size - } - - /// Splits a circle domain into a smaller [CircleDomain]s, shifted by offsets. - pub fn split(&self, log_parts: u32) -> (CircleDomain, Vec) { - assert!(log_parts <= self.half_coset.log_size); - let subdomain = CircleDomain::new(Coset::new( - self.half_coset.initial_index, - self.half_coset.log_size - log_parts, - )); - let shifts = (0..1 << log_parts) - .map(|i| self.half_coset.step_size * i) - .collect_vec(); - (subdomain, shifts) - } - - pub fn shift(&self, shift: CirclePointIndex) -> CircleDomain { - CircleDomain::new(self.half_coset.shift(shift)) - } -} - -impl IntoIterator for CircleDomain { - type Item = CirclePoint; - type IntoIter = CircleDomainIterator; - - /// Iterates over the points in the domain. - fn into_iter(self) -> CircleDomainIterator { - self.iter() - } -} - -/// An iterator over points in a circle domain. -/// -/// Let the domain be `+-c + `. The first iterated points are `c + `, then `-c + <-G>`. -pub type CircleDomainIterator = - Chain>, CosetIterator>>; - -/// Like [CircleDomainIterator] but returns corresponding [CirclePointIndex]s. -type CircleDomainIndexIterator = - Chain, CosetIterator>; - -#[cfg(test)] -mod tests { - use itertools::Itertools; - - use super::CircleDomain; - use crate::core::circle::{CirclePointIndex, Coset}; - use crate::core::poly::circle::CanonicCoset; - - #[test] - fn test_circle_domain_iterator() { - let domain = CircleDomain::new(Coset::new(CirclePointIndex::generator(), 2)); - for (i, point) in domain.iter().enumerate() { - if i < 4 { - assert_eq!( - point, - (CirclePointIndex::generator() + CirclePointIndex::subgroup_gen(2) * i) - .to_point() - ); - } else { - assert_eq!( - point, - (-(CirclePointIndex::generator() + CirclePointIndex::subgroup_gen(2) * i)) - .to_point() - ); - } - } - } - - #[test] - fn is_canonic_invalid_domain() { - let half_coset = Coset::new(CirclePointIndex::generator(), 4); - let not_canonic_domain = CircleDomain::new(half_coset); - - assert!(!not_canonic_domain.is_canonic()); - } - - #[test] - fn test_at_circle_domain() { - let domain = CanonicCoset::new(7).circle_domain(); - let half_domain_size = domain.size() / 2; - - for i in 0..half_domain_size { - assert_eq!(domain.index_at(i), -domain.index_at(i + half_domain_size)); - assert_eq!(domain.at(i), domain.at(i + half_domain_size).conjugate()); - } - } - - #[test] - fn test_domain_split() { - let domain = CanonicCoset::new(5).circle_domain(); - let (subdomain, shifts) = domain.split(2); - - let domain_points = domain.iter().collect::>(); - let points_for_each_domain = shifts - .iter() - .map(|&shift| (subdomain.shift(shift)).iter().collect_vec()) - .collect::>(); - // Interleave the points from each subdomain. - let extended_points = (0..(1 << 3)) - .flat_map(|point_ind| { - (0..(1 << 2)) - .map(|shift_ind| points_for_each_domain[shift_ind][point_ind]) - .collect_vec() - }) - .collect_vec(); - assert_eq!(domain_points, extended_points); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/circle/evaluation.rs b/Stwo_wrapper/crates/prover/src/core/poly/circle/evaluation.rs deleted file mode 100644 index 4cf23b9..0000000 --- a/Stwo_wrapper/crates/prover/src/core/poly/circle/evaluation.rs +++ /dev/null @@ -1,218 +0,0 @@ -use std::marker::PhantomData; -use std::ops::{Deref, Index}; - -use educe::Educe; - -use super::{CanonicCoset, CircleDomain, CirclePoly, PolyOps}; -use crate::core::backend::cpu::CpuCircleEvaluation; -use crate::core::backend::{Col, Column}; -use crate::core::circle::{CirclePointIndex, Coset}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::{ExtensionOf, FieldOps}; -use crate::core::poly::twiddles::TwiddleTree; -use crate::core::poly::{BitReversedOrder, NaturalOrder}; -use crate::core::utils::bit_reverse_index; - -/// An evaluation defined on a [CircleDomain]. -/// The values are ordered according to the [CircleDomain] ordering. -#[derive(Educe)] -#[educe(Clone, Debug)] -pub struct CircleEvaluation, F: ExtensionOf, EvalOrder = NaturalOrder> { - pub domain: CircleDomain, - pub values: Col, - _eval_order: PhantomData, -} - -impl, F: ExtensionOf, EvalOrder> CircleEvaluation { - pub fn new(domain: CircleDomain, values: Col) -> Self { - assert_eq!(domain.size(), values.len()); - Self { - domain, - values, - _eval_order: PhantomData, - } - } -} - -// Note: The concrete implementation of the poly operations is in the specific backend used. -// For example, the CPU backend implementation is in `src/core/backend/cpu/poly.rs`. -impl, B: FieldOps> CircleEvaluation { - // TODO(spapini): Remove. Is this even used. - pub fn get_at(&self, point_index: CirclePointIndex) -> F { - self.values - .at(self.domain.find(point_index).expect("Not in domain")) - } - - pub fn bit_reverse(mut self) -> CircleEvaluation { - B::bit_reverse_column(&mut self.values); - CircleEvaluation::new(self.domain, self.values) - } -} - -impl> CpuCircleEvaluation { - pub fn fetch_eval_on_coset(&self, coset: Coset) -> CosetSubEvaluation<'_, F> { - assert!(coset.log_size() <= self.domain.half_coset.log_size()); - if let Some(offset) = self.domain.half_coset.find(coset.initial_index) { - return CosetSubEvaluation::new( - &self.values[..self.domain.half_coset.size()], - offset, - coset.step_size / self.domain.half_coset.step_size, - ); - } - if let Some(offset) = self.domain.half_coset.conjugate().find(coset.initial_index) { - return CosetSubEvaluation::new( - &self.values[self.domain.half_coset.size()..], - offset, - (-coset.step_size) / self.domain.half_coset.step_size, - ); - } - panic!("Coset not found in domain"); - } -} - -impl CircleEvaluation { - /// Creates a [CircleEvaluation] from values ordered according to - /// [CanonicCoset]. For example, the canonic coset might look like this: - /// G_8, G_8 + G_4, G_8 + 2G_4, G_8 + 3G_4. - /// The circle domain will be ordered like this: - /// G_8, G_8 + 2G_4, -G_8, -G_8 - 2G_4. - pub fn new_canonical_ordered(coset: CanonicCoset, values: Col) -> Self { - B::new_canonical_ordered(coset, values) - } - - /// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation. - pub fn interpolate(self) -> CirclePoly { - let coset = self.domain.half_coset; - B::interpolate(self, &B::precompute_twiddles(coset)) - } - - /// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation, using - /// precomputed twiddles. - pub fn interpolate_with_twiddles(self, twiddles: &TwiddleTree) -> CirclePoly { - B::interpolate(self, twiddles) - } -} - -impl, F: ExtensionOf> CircleEvaluation { - pub fn bit_reverse(mut self) -> CircleEvaluation { - B::bit_reverse_column(&mut self.values); - CircleEvaluation::new(self.domain, self.values) - } - - pub fn get_at(&self, point_index: CirclePointIndex) -> F { - self.values.at(bit_reverse_index( - self.domain.find(point_index).expect("Not in domain"), - self.domain.log_size(), - )) - } -} - -impl, F: ExtensionOf, EvalOrder> Deref - for CircleEvaluation -{ - type Target = Col; - - fn deref(&self) -> &Self::Target { - &self.values - } -} - -/// A part of a [CircleEvaluation], for a specific coset that is a subset of the circle domain. -pub struct CosetSubEvaluation<'a, F: ExtensionOf> { - evaluation: &'a [F], - offset: usize, - step: isize, -} - -impl<'a, F: ExtensionOf> CosetSubEvaluation<'a, F> { - fn new(evaluation: &'a [F], offset: usize, step: isize) -> Self { - assert!(evaluation.len().is_power_of_two()); - Self { - evaluation, - offset, - step, - } - } -} - -impl<'a, F: ExtensionOf> Index for CosetSubEvaluation<'a, F> { - type Output = F; - - fn index(&self, index: isize) -> &Self::Output { - let index = - ((self.offset as isize) + index * self.step) & ((self.evaluation.len() - 1) as isize); - &self.evaluation[index as usize] - } -} - -impl<'a, F: ExtensionOf> Index for CosetSubEvaluation<'a, F> { - type Output = F; - - fn index(&self, index: usize) -> &Self::Output { - &self[index as isize] - } -} - -#[cfg(test)] -mod tests { - use crate::core::backend::cpu::CpuCircleEvaluation; - use crate::core::circle::Coset; - use crate::core::fields::m31::BaseField; - use crate::core::poly::circle::CanonicCoset; - use crate::core::poly::NaturalOrder; - use crate::m31; - - #[test] - fn test_interpolate_non_canonic() { - let domain = CanonicCoset::new(3).circle_domain(); - assert_eq!(domain.log_size(), 3); - let evaluation = CpuCircleEvaluation::<_, NaturalOrder>::new( - domain, - (0..8).map(BaseField::from_u32_unchecked).collect(), - ) - .bit_reverse(); - let poly = evaluation.interpolate(); - for (i, point) in domain.iter().enumerate() { - assert_eq!(poly.eval_at_point(point.into_ef()), m31!(i as u32).into()); - } - } - - #[test] - fn test_interpolate_canonic() { - let coset = CanonicCoset::new(3); - let evaluation = CpuCircleEvaluation::new_canonical_ordered( - coset, - (0..8).map(BaseField::from_u32_unchecked).collect(), - ); - let poly = evaluation.interpolate(); - for (i, point) in Coset::odds(3).iter().enumerate() { - assert_eq!(poly.eval_at_point(point.into_ef()), m31!(i as u32).into()); - } - } - - #[test] - pub fn test_get_at_circle_evaluation() { - let domain = CanonicCoset::new(7).circle_domain(); - let values = (0..domain.size()).map(|i| m31!(i as u32)).collect(); - let circle_evaluation = CpuCircleEvaluation::<_, NaturalOrder>::new(domain, values); - let bit_reversed_circle_evaluation = circle_evaluation.clone().bit_reverse(); - for index in domain.iter_indices() { - assert_eq!( - circle_evaluation.get_at(index), - bit_reversed_circle_evaluation.get_at(index) - ); - } - } - - #[test] - fn test_sub_evaluation() { - let domain = CanonicCoset::new(7).circle_domain(); - let values = (0..domain.size()).map(|i| m31!(i as u32)).collect(); - let circle_evaluation = CpuCircleEvaluation::new(domain, values); - let coset = Coset::new(domain.index_at(17), 3); - let sub_eval = circle_evaluation.fetch_eval_on_coset(coset); - for i in 0..coset.size() { - assert_eq!(sub_eval[i], circle_evaluation.get_at(coset.index_at(i))); - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/circle/mod.rs b/Stwo_wrapper/crates/prover/src/core/poly/circle/mod.rs deleted file mode 100644 index f2532d5..0000000 --- a/Stwo_wrapper/crates/prover/src/core/poly/circle/mod.rs +++ /dev/null @@ -1,56 +0,0 @@ -mod canonic; -mod domain; -mod evaluation; -mod ops; -mod poly; -mod secure_poly; - -pub use canonic::CanonicCoset; -pub use domain::{CircleDomain, MAX_CIRCLE_DOMAIN_LOG_SIZE}; -pub use evaluation::{CircleEvaluation, CosetSubEvaluation}; -pub use ops::PolyOps; -pub use poly::CirclePoly; -pub use secure_poly::{SecureCirclePoly, SecureEvaluation}; - -#[cfg(test)] -mod tests { - use super::CanonicCoset; - use crate::core::backend::cpu::CpuCircleEvaluation; - use crate::core::fields::m31::BaseField; - use crate::core::utils::bit_reverse_index; - - #[test] - fn test_interpolate_and_eval() { - let domain = CanonicCoset::new(3).circle_domain(); - assert_eq!(domain.log_size(), 3); - let evaluation = - CpuCircleEvaluation::new(domain, (0..8).map(BaseField::from_u32_unchecked).collect()); - let poly = evaluation.clone().interpolate(); - let evaluation2 = poly.evaluate(domain); - assert_eq!(evaluation.values, evaluation2.values); - } - - #[test] - fn is_canonic_valid_domain() { - let canonic_domain = CanonicCoset::new(4).circle_domain(); - - assert!(canonic_domain.is_canonic()); - } - - #[test] - pub fn test_bit_reverse_indices() { - let log_domain_size = 7; - let log_small_domain_size = 5; - let domain = CanonicCoset::new(log_domain_size); - let small_domain = CanonicCoset::new(log_small_domain_size); - let n_folds = log_domain_size - log_small_domain_size; - for i in 0..2usize.pow(log_domain_size) { - let point = domain.at(bit_reverse_index(i, log_domain_size)); - let small_point = small_domain.at(bit_reverse_index( - i / 2usize.pow(n_folds), - log_small_domain_size, - )); - assert_eq!(point.repeated_double(n_folds), small_point); - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/circle/ops.rs b/Stwo_wrapper/crates/prover/src/core/poly/circle/ops.rs deleted file mode 100644 index 40b86cb..0000000 --- a/Stwo_wrapper/crates/prover/src/core/poly/circle/ops.rs +++ /dev/null @@ -1,48 +0,0 @@ -use super::{CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly}; -use crate::core::backend::Col; -use crate::core::circle::{CirclePoint, Coset}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::FieldOps; -use crate::core::poly::twiddles::TwiddleTree; -use crate::core::poly::BitReversedOrder; - -/// Operations on BaseField polynomials. -pub trait PolyOps: FieldOps + Sized { - // TODO(spapini): Use a column instead of this type. - /// The type for precomputed twiddles. - type Twiddles; - - /// Creates a [CircleEvaluation] from values ordered according to [CanonicCoset]. - /// Used by the [`CircleEvaluation::new_canonical_ordered()`] function. - fn new_canonical_ordered( - coset: CanonicCoset, - values: Col, - ) -> CircleEvaluation; - - /// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation. - /// Used by the [`CircleEvaluation::interpolate()`] function. - fn interpolate( - eval: CircleEvaluation, - itwiddles: &TwiddleTree, - ) -> CirclePoly; - - /// Evaluates the polynomial at a single point. - /// Used by the [`CirclePoly::eval_at_point()`] function. - fn eval_at_point(poly: &CirclePoly, point: CirclePoint) -> SecureField; - - /// Extends the polynomial to a larger degree bound. - /// Used by the [`CirclePoly::extend()`] function. - fn extend(poly: &CirclePoly, log_size: u32) -> CirclePoly; - - /// Evaluates the polynomial at all points in the domain. - /// Used by the [`CirclePoly::evaluate()`] function. - fn evaluate( - poly: &CirclePoly, - domain: CircleDomain, - twiddles: &TwiddleTree, - ) -> CircleEvaluation; - - /// Precomputes twiddles for a given coset. - fn precompute_twiddles(coset: Coset) -> TwiddleTree; -} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/circle/poly.rs b/Stwo_wrapper/crates/prover/src/core/poly/circle/poly.rs deleted file mode 100644 index c10fc5e..0000000 --- a/Stwo_wrapper/crates/prover/src/core/poly/circle/poly.rs +++ /dev/null @@ -1,118 +0,0 @@ -use super::{CircleDomain, CircleEvaluation, PolyOps}; -use crate::core::backend::{Col, Column}; -use crate::core::circle::CirclePoint; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::FieldOps; -use crate::core::poly::twiddles::TwiddleTree; -use crate::core::poly::BitReversedOrder; - -/// A polynomial defined on a [CircleDomain]. -#[derive(Clone, Debug)] -pub struct CirclePoly> { - /// Coefficients of the polynomial in the FFT basis. - /// Note: These are not the coefficients of the polynomial in the standard - /// monomial basis. The FFT basis is a tensor product of the twiddles: - /// y, x, pi(x), pi^2(x), ..., pi^{log_size-2}(x). - /// pi(x) := 2x^2 - 1. - pub coeffs: Col, - /// The number of coefficients stored as `log2(len(coeffs))`. - log_size: u32, -} - -impl CirclePoly { - /// Creates a new circle polynomial. - /// - /// Coefficients must be in the circle IFFT algorithm's basis stored in bit-reversed order. - /// - /// # Panics - /// - /// Panics if the number of coefficients isn't a power of two. - pub fn new(coeffs: Col) -> Self { - assert!(coeffs.len().is_power_of_two()); - let log_size = coeffs.len().ilog2(); - Self { log_size, coeffs } - } - - pub fn log_size(&self) -> u32 { - self.log_size - } - - /// Evaluates the polynomial at a single point. - pub fn eval_at_point(&self, point: CirclePoint) -> SecureField { - B::eval_at_point(self, point) - } - - /// Extends the polynomial to a larger degree bound. - pub fn extend(&self, log_size: u32) -> Self { - B::extend(self, log_size) - } - - /// Evaluates the polynomial at all points in the domain. - pub fn evaluate( - &self, - domain: CircleDomain, - ) -> CircleEvaluation { - B::evaluate(self, domain, &B::precompute_twiddles(domain.half_coset)) - } - - /// Evaluates the polynomial at all points in the domain, using precomputed twiddles. - pub fn evaluate_with_twiddles( - &self, - domain: CircleDomain, - twiddles: &TwiddleTree, - ) -> CircleEvaluation { - B::evaluate(self, domain, twiddles) - } -} - -#[cfg(test)] -impl crate::core::backend::cpu::CpuCirclePoly { - pub fn is_in_fft_space(&self, log_fft_size: u32) -> bool { - use num_traits::Zero; - - let mut coeffs = self.coeffs.clone(); - while coeffs.last() == Some(&BaseField::zero()) { - coeffs.pop(); - } - - // The highest degree monomial in a fft-space polynomial is x^{(n/2) - 1}y. - // And it is at offset (n-1). x^{(n/2)} is at offset `n`, and is not allowed. - let highest_degree_allowed_monomial_offset = 1 << log_fft_size; - coeffs.len() <= highest_degree_allowed_monomial_offset - } - - /// Fri space is the space of polynomials of total degree n/2. - /// Highest degree monomials are x^{n/2} and x^{(n/2)-1}y. - pub fn is_in_fri_space(&self, log_fft_size: u32) -> bool { - use num_traits::Zero; - - let mut coeffs = self.coeffs.clone(); - while coeffs.last() == Some(&BaseField::zero()) { - coeffs.pop(); - } - - // x^{n/2} is at offset `n`, and is the last offset allowed to be non-zero. - let highest_degree_monomial_offset = (1 << log_fft_size) + 1; - coeffs.len() <= highest_degree_monomial_offset - } -} - -#[cfg(test)] -mod tests { - use crate::core::backend::cpu::CpuCirclePoly; - use crate::core::circle::CirclePoint; - use crate::core::fields::m31::BaseField; - - #[test] - fn test_circle_poly_extend() { - let poly = CpuCirclePoly::new((0..16).map(BaseField::from_u32_unchecked).collect()); - let extended = poly.clone().extend(8); - let random_point = CirclePoint::get_point(21903); - - assert_eq!( - poly.eval_at_point(random_point), - extended.eval_at_point(random_point) - ); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/circle/secure_poly.rs b/Stwo_wrapper/crates/prover/src/core/poly/circle/secure_poly.rs deleted file mode 100644 index a503bd2..0000000 --- a/Stwo_wrapper/crates/prover/src/core/poly/circle/secure_poly.rs +++ /dev/null @@ -1,118 +0,0 @@ -use std::marker::PhantomData; -use std::ops::{Deref, DerefMut}; - -use super::{CircleDomain, CircleEvaluation, CirclePoly, PolyOps}; -use crate::core::backend::CpuBackend; -use crate::core::circle::CirclePoint; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE}; -use crate::core::fields::FieldOps; -use crate::core::poly::twiddles::TwiddleTree; -use crate::core::poly::BitReversedOrder; - -pub struct SecureCirclePoly>(pub [CirclePoly; SECURE_EXTENSION_DEGREE]); - -impl SecureCirclePoly { - pub fn eval_at_point(&self, point: CirclePoint) -> SecureField { - SecureField::from_partial_evals(self.eval_columns_at_point(point)) - } - - pub fn eval_columns_at_point( - &self, - point: CirclePoint, - ) -> [SecureField; SECURE_EXTENSION_DEGREE] { - [ - self[0].eval_at_point(point), - self[1].eval_at_point(point), - self[2].eval_at_point(point), - self[3].eval_at_point(point), - ] - } - - pub fn log_size(&self) -> u32 { - self[0].log_size() - } - - pub fn evaluate_with_twiddles( - &self, - domain: CircleDomain, - twiddles: &TwiddleTree, - ) -> SecureEvaluation { - let polys = self.0.each_ref(); - let columns = polys.map(|poly| poly.evaluate_with_twiddles(domain, twiddles).values); - SecureEvaluation::new(domain, SecureColumnByCoords { columns }) - } -} - -impl> Deref for SecureCirclePoly { - type Target = [CirclePoly; SECURE_EXTENSION_DEGREE]; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -/// A [`SecureField`] evaluation defined on a [CircleDomain]. -/// -/// The evaluation is stored as a column major array of [`SECURE_EXTENSION_DEGREE`] many base field -/// evaluations. The evaluations are ordered according to the [CircleDomain] ordering. -#[derive(Clone)] -pub struct SecureEvaluation, EvalOrder> { - pub domain: CircleDomain, - pub values: SecureColumnByCoords, - _eval_order: PhantomData, -} - -impl, EvalOrder> SecureEvaluation { - pub fn new(domain: CircleDomain, values: SecureColumnByCoords) -> Self { - assert_eq!(domain.size(), values.len()); - Self { - domain, - values, - _eval_order: PhantomData, - } - } - - pub fn into_coordinate_evals( - self, - ) -> [CircleEvaluation; SECURE_EXTENSION_DEGREE] { - let Self { domain, values, .. } = self; - values.columns.map(|c| CircleEvaluation::new(domain, c)) - } -} - -impl, EvalOrder> Deref for SecureEvaluation { - type Target = SecureColumnByCoords; - - fn deref(&self) -> &Self::Target { - &self.values - } -} - -impl, EvalOrder> DerefMut for SecureEvaluation { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.values - } -} - -impl SecureEvaluation { - /// Computes a minimal [`SecureCirclePoly`] that evaluates to the same values as this - /// evaluation, using precomputed twiddles. - pub fn interpolate_with_twiddles(self, twiddles: &TwiddleTree) -> SecureCirclePoly { - let domain = self.domain; - let cols = self.values.columns; - SecureCirclePoly(cols.map(|c| { - CircleEvaluation::::new(domain, c) - .interpolate_with_twiddles(twiddles) - })) - } -} - -impl From> - for SecureEvaluation -{ - fn from(evaluation: CircleEvaluation) -> Self { - Self::new(evaluation.domain, evaluation.values.into_iter().collect()) - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/line.rs b/Stwo_wrapper/crates/prover/src/core/poly/line.rs deleted file mode 100644 index 4dac73e..0000000 --- a/Stwo_wrapper/crates/prover/src/core/poly/line.rs +++ /dev/null @@ -1,408 +0,0 @@ -use std::cmp::Ordering; -use std::fmt::Debug; -use std::iter::Map; -use std::ops::{Deref, DerefMut}; - -use itertools::Itertools; -use num_traits::Zero; -use serde::{Deserialize, Serialize}; - -use super::circle::CircleDomain; -use super::utils::fold; -use crate::core::backend::{ColumnOps, CpuBackend}; -use crate::core::circle::{CirclePoint, Coset, CosetIterator}; -use crate::core::fft::ibutterfly; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SecureColumnByCoords; -use crate::core::fields::{ExtensionOf, FieldExpOps, FieldOps}; -use crate::core::utils::bit_reverse; - -/// Domain comprising of the x-coordinates of points in a [Coset]. -/// -/// For use with univariate polynomials. -#[derive(Copy, Clone, Debug)] -pub struct LineDomain { - coset: Coset, -} - -impl LineDomain { - /// Returns a domain comprising of the x-coordinates of points in a coset. - /// - /// # Panics - /// - /// Panics if the coset items don't have unique x-coordinates. - pub fn new(coset: Coset) -> Self { - match coset.size().cmp(&2) { - Ordering::Less => {} - Ordering::Equal => { - // If the coset with two points contains (0, y) then the coset is {(0, y), (0, -y)}. - assert!(!coset.initial.x.is_zero(), "coset x-coordinates not unique"); - } - Ordering::Greater => { - // Let our coset be `E = c + ` with `|E| > 2` then: - // 1. if `ord(c) <= ord(G)` the coset contains two points at x=0 - // 2. if `ord(c) = 2 * ord(G)` then `c` and `-c` are in our coset - assert!( - coset.initial.log_order() >= coset.step.log_order() + 2, - "coset x-coordinates not unique" - ); - } - } - Self { coset } - } - - /// Returns the `i`th domain element. - pub fn at(&self, i: usize) -> BaseField { - self.coset.at(i).x - } - - /// Returns the size of the domain. - pub fn size(&self) -> usize { - self.coset.size() - } - - /// Returns the log size of the domain. - pub fn log_size(&self) -> u32 { - self.coset.log_size() - } - - /// Returns an iterator over elements in the domain. - pub fn iter(&self) -> LineDomainIterator { - self.coset.iter().map(|p| p.x) - } - - /// Returns a new domain comprising of all points in current domain doubled. - pub fn double(&self) -> Self { - Self { - coset: self.coset.double(), - } - } - - /// Returns the domain's underlying coset. - pub fn coset(&self) -> Coset { - self.coset - } -} - -impl IntoIterator for LineDomain { - type Item = BaseField; - type IntoIter = LineDomainIterator; - - /// Returns an iterator over elements in the domain. - fn into_iter(self) -> LineDomainIterator { - self.iter() - } -} - -impl From for LineDomain { - fn from(domain: CircleDomain) -> Self { - Self { - coset: domain.half_coset, - } - } -} - -/// An iterator over the x-coordinates of points in a coset. -type LineDomainIterator = - Map>, fn(CirclePoint) -> BaseField>; - -/// A univariate polynomial defined on a [LineDomain]. -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize)] -pub struct LinePoly { - /// Coefficients of the polynomial in [line_ifft] algorithm's basis. - /// - /// The coefficients are stored in bit-reversed order. - pub coeffs: Vec, - /// The number of coefficients stored as `log2(len(coeffs))`. - log_size: u32, -} - -impl LinePoly { - /// Creates a new line polynomial from bit reversed coefficients. - /// - /// # Panics - /// - /// Panics if the number of coefficients is not a power of two. - pub fn new(coeffs: Vec) -> Self { - assert!(coeffs.len().is_power_of_two()); - let log_size = coeffs.len().ilog2(); - Self { coeffs, log_size } - } - - /// Evaluates the polynomial at a single point. - pub fn eval_at_point(&self, mut x: SecureField) -> SecureField { - let mut doublings = Vec::new(); - for _ in 0..self.log_size { - doublings.push(x); - x = CirclePoint::double_x(x); - } - fold(&self.coeffs, &doublings) - } - - /// Returns the number of coefficients. - #[allow(clippy::len_without_is_empty)] - pub fn len(&self) -> usize { - // `.len().ilog2()` is a common operation. By returning the length like so the compiler - // optimizes `.len().ilog2()` to a load of `log_size` instead of a branch and a bit count. - debug_assert_eq!(self.coeffs.len(), 1 << self.log_size); - 1 << self.log_size - } - - /// Returns the polynomial's coefficients in their natural order. - pub fn into_ordered_coefficients(mut self) -> Vec { - bit_reverse(&mut self.coeffs); - self.coeffs - } - - /// Creates a new line polynomial from coefficients in their natural order. - /// - /// # Panics - /// - /// Panics if the number of coefficients is not a power of two. - pub fn from_ordered_coefficients(mut coeffs: Vec) -> Self { - bit_reverse(&mut coeffs); - Self::new(coeffs) - } -} - -impl Deref for LinePoly { - type Target = [SecureField]; - - fn deref(&self) -> &[SecureField] { - &self.coeffs - } -} - -impl DerefMut for LinePoly { - fn deref_mut(&mut self) -> &mut [SecureField] { - &mut self.coeffs - } -} - -/// Evaluations of a univariate polynomial on a [LineDomain]. -// TODO(andrew): Remove EvalOrder. Bit-reversed evals are only necessary since LineEvaluation is -// only used by FRI where evaluations are in bit-reversed order. -// TODO(spapini): Remove pub. -#[derive(Clone, Debug)] -pub struct LineEvaluation> { - /// Evaluations of a univariate polynomial on `domain`. - pub values: SecureColumnByCoords, - domain: LineDomain, -} - -impl> LineEvaluation { - /// Creates new [LineEvaluation] from a set of polynomial evaluations over a [LineDomain]. - /// - /// # Panics - /// - /// Panics if the number of evaluations does not match the size of the domain. - pub fn new(domain: LineDomain, values: SecureColumnByCoords) -> Self { - assert_eq!(values.len(), domain.size()); - Self { values, domain } - } - - pub fn new_zero(domain: LineDomain) -> Self { - Self::new(domain, SecureColumnByCoords::zeros(domain.size())) - } - - /// Returns the number of evaluations. - #[allow(clippy::len_without_is_empty)] - pub fn len(&self) -> usize { - 1 << self.domain.log_size() - } - - pub fn domain(&self) -> LineDomain { - self.domain - } - - /// Clones the values into a new line evaluation in the CPU. - pub fn to_cpu(&self) -> LineEvaluation { - LineEvaluation::new(self.domain, self.values.to_cpu()) - } -} - -impl LineEvaluation { - /// Interpolates the polynomial as evaluations on `domain`. - pub fn interpolate(self) -> LinePoly { - let mut values = self.values.into_iter().collect_vec(); - CpuBackend::bit_reverse_column(&mut values); - line_ifft(&mut values, self.domain); - // Normalize the coefficients. - let len_inv = BaseField::from(values.len()).inverse(); - values.iter_mut().for_each(|v| *v *= len_inv); - LinePoly::new(values) - } -} - -/// Performs a univariate IFFT on a polynomial's evaluation over a [LineDomain]. -/// -/// This is not the standard univariate IFFT, because [LineDomain] is not a cyclic group. -/// -/// The transform happens in-place. `values` should be the evaluations of a polynomial over `domain` -/// in their natural order. After the transformation `values` becomes the coefficients of the -/// polynomial stored in bit-reversed order. -/// -/// For performance reasons and flexibility the normalization of the coefficients is omitted. The -/// normalized coefficients can be obtained by scaling all coefficients by `1 / len(values)`. -/// -/// This algorithm does not return coefficients in the standard monomial basis but rather returns -/// coefficients in a basis relating to the circle's x-coordinate doubling map `pi(x) = 2x^2 - 1` -/// i.e. -/// -/// ```text -/// B = { 1 } * { x } * { pi(x) } * { pi(pi(x)) } * ... -/// = { 1, x, pi(x), pi(x) * x, pi(pi(x)), pi(pi(x)) * x, pi(pi(x)) * pi(x), ... } -/// ``` -/// -/// # Panics -/// -/// Panics if the number of values doesn't match the size of the domain. -fn line_ifft>(values: &mut [F], mut domain: LineDomain) { - assert_eq!(values.len(), domain.size()); - while domain.size() > 1 { - for chunk in values.chunks_exact_mut(domain.size()) { - let (l, r) = chunk.split_at_mut(domain.size() / 2); - for (i, x) in domain.iter().take(domain.size() / 2).enumerate() { - ibutterfly(&mut l[i], &mut r[i], x.inverse()); - } - } - domain = domain.double(); - } -} - -#[cfg(test)] -mod tests { - type B = CpuBackend; - - use itertools::Itertools; - - use super::LineDomain; - use crate::core::backend::{ColumnOps, CpuBackend}; - use crate::core::circle::{CirclePoint, Coset}; - use crate::core::fields::m31::BaseField; - use crate::core::poly::line::{LineEvaluation, LinePoly}; - use crate::core::utils::bit_reverse_index; - - #[test] - #[should_panic] - fn bad_line_domain() { - // This coset doesn't have points with unique x-coordinates. - let coset = Coset::odds(2); - - LineDomain::new(coset); - } - - #[test] - fn line_domain_of_size_two_works() { - const LOG_SIZE: u32 = 1; - let coset = Coset::subgroup(LOG_SIZE); - - LineDomain::new(coset); - } - - #[test] - fn line_domain_of_size_one_works() { - const LOG_SIZE: u32 = 0; - let coset = Coset::subgroup(LOG_SIZE); - - LineDomain::new(coset); - } - - #[test] - fn line_domain_size_is_correct() { - const LOG_SIZE: u32 = 8; - let coset = Coset::half_odds(LOG_SIZE); - let domain = LineDomain::new(coset); - - let size = domain.size(); - - assert_eq!(size, 1 << LOG_SIZE); - } - - #[test] - fn line_domain_coset_returns_the_coset() { - let coset = Coset::half_odds(5); - let domain = LineDomain::new(coset); - - assert_eq!(domain.coset(), coset); - } - - #[test] - fn line_domain_double_works() { - const LOG_SIZE: u32 = 8; - let coset = Coset::half_odds(LOG_SIZE); - let domain = LineDomain::new(coset); - - let doubled_domain = domain.double(); - - assert_eq!(doubled_domain.size(), 1 << (LOG_SIZE - 1)); - assert_eq!(doubled_domain.at(0), CirclePoint::double_x(domain.at(0))); - assert_eq!(doubled_domain.at(1), CirclePoint::double_x(domain.at(1))); - } - - #[test] - fn line_domain_iter_works() { - const LOG_SIZE: u32 = 8; - let coset = Coset::half_odds(LOG_SIZE); - let domain = LineDomain::new(coset); - - let elements = domain.iter().collect::>(); - - assert_eq!(elements.len(), domain.size()); - for (i, element) in elements.into_iter().enumerate() { - assert_eq!(element, domain.at(i), "mismatch at {i}"); - } - } - - #[test] - fn line_evaluation_interpolation() { - let poly = LinePoly::new(vec![ - BaseField::from(7).into(), // 7 * 1 - BaseField::from(9).into(), // 9 * pi(x) - BaseField::from(5).into(), // 5 * x - BaseField::from(3).into(), // 3 * pi(x)*x - ]); - let coset = Coset::half_odds(poly.len().ilog2()); - let domain = LineDomain::new(coset); - let mut values = domain - .iter() - .map(|x| { - let pi_x = CirclePoint::double_x(x); - poly.coeffs[0] - + poly.coeffs[1] * pi_x - + poly.coeffs[2] * x - + poly.coeffs[3] * pi_x * x - }) - .collect_vec(); - CpuBackend::bit_reverse_column(&mut values); - let evals = LineEvaluation::::new(domain, values.into_iter().collect()); - - let interpolated_poly = evals.interpolate(); - - assert_eq!(interpolated_poly.coeffs, poly.coeffs); - } - - #[test] - fn line_polynomial_eval_at_point() { - const LOG_SIZE: u32 = 2; - let coset = Coset::half_odds(LOG_SIZE); - let domain = LineDomain::new(coset); - let evals = LineEvaluation::::new( - domain, - (0..1 << LOG_SIZE) - .map(BaseField::from) - .map(|x| x.into()) - .collect(), - ); - let poly = evals.clone().interpolate(); - - for (i, x) in domain.iter().enumerate() { - assert_eq!( - poly.eval_at_point(x.into()), - evals.values.at(bit_reverse_index(i, domain.log_size())), - "mismatch at {i}" - ); - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/mod.rs b/Stwo_wrapper/crates/prover/src/core/poly/mod.rs deleted file mode 100644 index 301c698..0000000 --- a/Stwo_wrapper/crates/prover/src/core/poly/mod.rs +++ /dev/null @@ -1,14 +0,0 @@ -pub mod circle; -pub mod line; -// TODO(spapini): Remove pub, when LinePoly moved to the backend as well, and we can move the fold -// function there. -pub mod twiddles; -pub mod utils; - -/// Bit-reversed evaluation ordering. -#[derive(Copy, Clone, Debug)] -pub struct BitReversedOrder; - -/// Natural evaluation ordering (same order as domain). -#[derive(Copy, Clone, Debug)] -pub struct NaturalOrder; diff --git a/Stwo_wrapper/crates/prover/src/core/poly/twiddles.rs b/Stwo_wrapper/crates/prover/src/core/poly/twiddles.rs deleted file mode 100644 index 53ea476..0000000 --- a/Stwo_wrapper/crates/prover/src/core/poly/twiddles.rs +++ /dev/null @@ -1,13 +0,0 @@ -use super::circle::PolyOps; -use crate::core::circle::Coset; - -/// Precomputed twiddles for a specific coset tower. -/// A coset tower is every repeated doubling of a `root_coset`. -/// The largest CircleDomain that can be ffted using these twiddles is one with `root_coset` as -/// its `half_coset`. -pub struct TwiddleTree { - pub root_coset: Coset, - // TODO(spapini): Represent a slice, and grabbing, in a generic way - pub twiddles: B::Twiddles, - pub itwiddles: B::Twiddles, -} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/utils.rs b/Stwo_wrapper/crates/prover/src/core/poly/utils.rs deleted file mode 100644 index bc0dece..0000000 --- a/Stwo_wrapper/crates/prover/src/core/poly/utils.rs +++ /dev/null @@ -1,115 +0,0 @@ -use super::line::LineDomain; -use crate::core::fields::{ExtensionOf, Field}; - -/// Folds values recursively in `O(n)` by a hierarchical application of folding factors. -/// -/// i.e. folding `n = 8` values with `folding_factors = [x, y, z]`: -/// -/// ```text -/// n2=n1+x*n2 -/// / \ -/// n1=n3+y*n4 n2=n5+y*n6 -/// / \ / \ -/// n3=a+z*b n4=c+z*d n5=e+z*f n6=g+z*h -/// / \ / \ / \ / \ -/// a b c d e f g h -/// ``` -/// -/// # Panics -/// -/// Panics if the number of values is not a power of two or if an incorrect number of of folding -/// factors is provided. -// TODO(Andrew): Can be made to run >10x faster by unrolling lower layers of recursion -pub fn fold>(values: &[F], folding_factors: &[E]) -> E { - let n = values.len(); - assert_eq!(n, 1 << folding_factors.len()); - if n == 1 { - return values[0].into(); - } - let (lhs_values, rhs_values) = values.split_at(n / 2); - let (&folding_factor, folding_factors) = folding_factors.split_first().unwrap(); - let lhs_val = fold(lhs_values, folding_factors); - let rhs_val = fold(rhs_values, folding_factors); - lhs_val + rhs_val * folding_factor -} - -/// Repeats each value sequentially `duplicity` many times. -/// -/// # Examples -/// -/// ```rust -/// # use stwo_prover::core::poly::utils::repeat_value; -/// assert_eq!(repeat_value(&[1, 2, 3], 2), vec![1, 1, 2, 2, 3, 3]); -/// ``` -pub fn repeat_value(values: &[T], duplicity: usize) -> Vec { - let n = values.len(); - let mut res: Vec = Vec::with_capacity(n * duplicity); - - // Fill each chunk with its corresponding value. - for &v in values { - for _ in 0..duplicity { - res.push(v) - } - } - - res -} - -/// Computes the line twiddles for a [`CircleDomain`] or a [`LineDomain`] from the precomputed -/// twiddles tree. -/// -/// [`CircleDomain`]: super::circle::CircleDomain -pub fn domain_line_twiddles_from_tree( - domain: impl Into, - twiddle_buffer: &[T], -) -> Vec<&[T]> { - let domain = domain.into(); - debug_assert!(domain.coset().size() <= twiddle_buffer.len()); - (0..domain.coset().log_size()) - .map(|i| { - let len = 1 << i; - &twiddle_buffer[twiddle_buffer.len() - len * 2..twiddle_buffer.len() - len] - }) - .rev() - .collect() -} - -#[cfg(test)] -mod tests { - use super::repeat_value; - use crate::core::poly::circle::CanonicCoset; - use crate::core::poly::line::LineDomain; - use crate::core::poly::utils::domain_line_twiddles_from_tree; - - #[test] - fn repeat_value_0_times_works() { - assert!(repeat_value(&[1, 2, 3], 0).is_empty()); - } - - #[test] - fn repeat_value_2_times_works() { - assert_eq!(repeat_value(&[1, 2, 3], 2), vec![1, 1, 2, 2, 3, 3]); - } - - #[test] - fn repeat_value_3_times_works() { - assert_eq!(repeat_value(&[1, 2], 3), vec![1, 1, 1, 2, 2, 2]); - } - - #[test] - fn test_domain_line_twiddles_works() { - let domain: LineDomain = CanonicCoset::new(4).circle_domain().into(); - let twiddles = domain_line_twiddles_from_tree(domain, &[0, 1, 2, 3, 4, 5, 6, 7]); - assert_eq!(twiddles.len(), 3); - assert_eq!(twiddles[0], &[0, 1, 2, 3]); - assert_eq!(twiddles[1], &[4, 5]); - assert_eq!(twiddles[2], &[6]); - } - - #[test] - #[should_panic] - fn test_domain_line_twiddles_fails() { - let domain: LineDomain = CanonicCoset::new(5).circle_domain().into(); - domain_line_twiddles_from_tree(domain, &[0, 1, 2, 3, 4, 5, 6, 7]); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/proof_of_work.rs b/Stwo_wrapper/crates/prover/src/core/proof_of_work.rs deleted file mode 100644 index 1c61ad8..0000000 --- a/Stwo_wrapper/crates/prover/src/core/proof_of_work.rs +++ /dev/null @@ -1,7 +0,0 @@ -use crate::core::channel::Channel; - -pub trait GrindOps { - /// Searches for a nonce s.t. mixing it to the channel makes the digest have `pow_bits` leading - /// zero bits. - fn grind(channel: &C, pow_bits: u32) -> u64; -} diff --git a/Stwo_wrapper/crates/prover/src/core/prover/mod.rs b/Stwo_wrapper/crates/prover/src/core/prover/mod.rs deleted file mode 100644 index 30493fc..0000000 --- a/Stwo_wrapper/crates/prover/src/core/prover/mod.rs +++ /dev/null @@ -1,186 +0,0 @@ -use std::array; - -use thiserror::Error; -use tracing::{span, Level}; - -use super::air::{Component, ComponentProver, ComponentProvers, Components}; -use super::backend::BackendForChannel; -use super::channel::MerkleChannel; -use super::fields::secure_column::SECURE_EXTENSION_DEGREE; -use super::fri::FriVerificationError; -use super::pcs::{CommitmentSchemeProof, TreeVec}; -use super::vcs::ops::MerkleHasher; -use crate::core::backend::CpuBackend; -use crate::core::channel::Channel; -use crate::core::circle::CirclePoint; -use crate::core::fields::qm31::SecureField; -use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier}; -use crate::core::poly::circle::CircleEvaluation; -use crate::core::poly::BitReversedOrder; -use crate::core::vcs::verifier::MerkleVerificationError; - -#[derive(Debug)] -pub struct StarkProof { - pub commitments: TreeVec, - pub commitment_scheme_proof: CommitmentSchemeProof, -} - -#[derive(Debug)] -pub struct AdditionalProofData { - pub composition_polynomial_oods_value: SecureField, - pub composition_polynomial_random_coeff: SecureField, - pub oods_point: CirclePoint, - pub oods_quotients: Vec>, -} - -pub fn prove, MC: MerkleChannel>( - components: &[&dyn ComponentProver], - channel: &mut MC::C, - commitment_scheme: &mut CommitmentSchemeProver<'_, B, MC>, -) -> Result, ProvingError> { - let component_provers = ComponentProvers(components.to_vec()); - let trace = commitment_scheme.trace(); - - // Evaluate and commit on composition polynomial. (Hash(commitment[0]) - let random_coeff = channel.draw_felt(); - - let span = span!(Level::INFO, "Composition").entered(); - let span1 = span!(Level::INFO, "Generation").entered(); - // Linear random combination of the trace polynomials - let composition_polynomial_poly = - component_provers.compute_composition_polynomial(random_coeff, &trace); - span1.exit(); - - // Before tree_builder is only column polynomial - let mut tree_builder = commitment_scheme.tree_builder(); - // Extend it with composition polynomial - tree_builder.extend_polys(composition_polynomial_poly.to_vec()); - // Reevaluate every polynomial (including composition) and absorb the root (Transcript <-- root) - tree_builder.commit(channel); - span.exit(); - - // Draw OODS point. - let oods_point = CirclePoint::::get_random_point(channel); - - // Get mask sample points relative to oods point. - let mut sample_points = component_provers.components().mask_points(oods_point); - // Add the composition polynomial mask points. - sample_points.push(vec![vec![oods_point]; SECURE_EXTENSION_DEGREE]); - - // Prove the trace and composition OODS values, and retrieve them. - let commitment_scheme_proof = commitment_scheme.prove_values(sample_points, channel); - - let sampled_oods_values = &commitment_scheme_proof.sampled_values; - let composition_oods_eval = extract_composition_eval(sampled_oods_values).unwrap(); - - // Evaluate composition polynomial at OODS point and check that it matches the trace OODS - // values. This is a sanity check. - if composition_oods_eval - != component_provers - .components() - .eval_composition_polynomial_at_point(oods_point, sampled_oods_values, random_coeff) - { - return Err(ProvingError::ConstraintsNotSatisfied); - } - - Ok(StarkProof { - commitments: commitment_scheme.roots(), - commitment_scheme_proof, - }) -} - -pub fn verify( - components: &[&dyn Component], - channel: &mut MC::C, - commitment_scheme: &mut CommitmentSchemeVerifier, - proof: StarkProof, -) -> Result<(), VerificationError> { - let components = Components(components.to_vec()); - let random_coeff = channel.draw_felt(); - - // Read composition polynomial commitment. - commitment_scheme.commit( - *proof.commitments.last().unwrap(), - &[components.composition_log_degree_bound(); SECURE_EXTENSION_DEGREE], - channel, - ); - - // Draw OODS point z. - let oods_point = CirclePoint::::get_random_point(channel); - - // Get mask sample points relative to oods point. - let mut sample_points = components.mask_points(oods_point); - // Add the composition polynomial mask points. - sample_points.push(vec![vec![oods_point]; SECURE_EXTENSION_DEGREE]); - - let sampled_oods_values = &proof.commitment_scheme_proof.sampled_values; - // Compute h0(z) + i * h1(z) + u * h2(z) + u*i* h3(z) - let composition_oods_eval = extract_composition_eval(sampled_oods_values).map_err(|_| { - VerificationError::InvalidStructure("Unexpected sampled_values structure".to_string()) - })?; - - if composition_oods_eval - // Compute - != components.eval_composition_polynomial_at_point( - oods_point, - sampled_oods_values, - random_coeff, - ) - { - return Err(VerificationError::OodsNotMatching); - } - - commitment_scheme.verify_values(sample_points, proof.commitment_scheme_proof, channel) -} - -/// Extracts the composition trace evaluation from the mask. -fn extract_composition_eval( - mask: &TreeVec>>, -) -> Result { - // Last part of the proof - let mut composition_cols = mask.last().into_iter().flatten(); - - let coordinate_evals = array::try_from_fn(|_| { - //Each Secure element of the sampled values - let col = &**composition_cols.next().ok_or(InvalidOodsSampleStructure)?; - let [eval] = col.try_into().map_err(|_| InvalidOodsSampleStructure)?; - Ok(eval) - })?; - - // Too many columns. - if composition_cols.next().is_some() { - return Err(InvalidOodsSampleStructure); - } - // Computes [0] + [1] * i + [2] * u + [3] * i * u - - Ok(SecureField::from_partial_evals(coordinate_evals)) -} - -/// Error when the sampled values have an invalid structure. -#[derive(Clone, Copy, Debug)] -pub struct InvalidOodsSampleStructure; - -#[derive(Clone, Copy, Debug, Error)] -pub enum ProvingError { - #[error("Constraints not satisfied.")] - ConstraintsNotSatisfied, -} - -#[derive(Clone, Debug, Error)] -pub enum VerificationError { - #[error("Proof has invalid structure: {0}.")] - InvalidStructure(String), - #[error("{0} lookup values do not match.")] - InvalidLookup(String), - #[error(transparent)] - Merkle(#[from] MerkleVerificationError), - #[error( - "The composition polynomial OODS value does not match the trace OODS values - (DEEP-ALI failure)." - )] - OodsNotMatching, - #[error(transparent)] - Fri(#[from] FriVerificationError), - #[error("Proof of work verification failed.")] - ProofOfWork, -} diff --git a/Stwo_wrapper/crates/prover/src/core/queries.rs b/Stwo_wrapper/crates/prover/src/core/queries.rs deleted file mode 100644 index 934edfd..0000000 --- a/Stwo_wrapper/crates/prover/src/core/queries.rs +++ /dev/null @@ -1,237 +0,0 @@ -use std::collections::BTreeSet; -use std::ops::Deref; - -use itertools::Itertools; - -use super::channel::Channel; -use super::circle::Coset; -use super::poly::circle::CircleDomain; -use super::utils::bit_reverse_index; - -pub const UPPER_BOUND_QUERY_BYTES: usize = 4; - -/// An ordered set of query indices over a bit reversed [CircleDomain]. -#[derive(Debug, Clone)] -pub struct Queries { - pub positions: Vec, - pub log_domain_size: u32, -} - -impl Queries { - /// Randomizes a set of query indices uniformly over the range [0, 2^`log_query_size`). - pub fn generate(channel: &mut impl Channel, log_domain_size: u32, n_queries: usize) -> Self { - let mut queries = BTreeSet::new(); - let mut query_cnt = 0; - let max_query = (1 << log_domain_size) - 1; - loop { - let random_bytes = channel.draw_random_bytes(); - for chunk in random_bytes.chunks_exact(UPPER_BOUND_QUERY_BYTES) { - let query_bits = u32::from_le_bytes(chunk.try_into().unwrap()); - let quotient_query = query_bits & max_query; - queries.insert(quotient_query as usize); - query_cnt += 1; - if query_cnt == n_queries { - return Self { - positions: queries.into_iter().collect(), - log_domain_size, - }; - } - } - } - } - - // TODO docs - #[allow(clippy::missing_safety_doc)] - pub fn from_positions(positions: Vec, log_domain_size: u32) -> Self { - assert!(positions.is_sorted()); - assert!(positions.iter().all(|p| *p < (1 << log_domain_size))); - Self { - positions, - log_domain_size, - } - } - - /// Calculates the matching query indices in a folded domain (i.e each domain point is doubled) - /// given `self` (the queries of the original domain) and the number of folds between domains. - pub fn fold(&self, n_folds: u32) -> Self { - assert!(n_folds <= self.log_domain_size); - Self { - positions: self.iter().map(|q| q >> n_folds).dedup().collect(), - log_domain_size: self.log_domain_size - n_folds, - } - } - - pub fn opening_positions(&self, fri_step_size: u32) -> SparseSubCircleDomain { - assert!(fri_step_size > 0); - SparseSubCircleDomain { - domains: self - .iter() - .map(|q| SubCircleDomain { - coset_index: q >> fri_step_size, - log_size: fri_step_size, - }) - .dedup() - .collect(), - large_domain_log_size: self.log_domain_size, - } - } -} - -impl Deref for Queries { - type Target = Vec; - - fn deref(&self) -> &Self::Target { - &self.positions - } -} - -#[derive(Debug, Eq, PartialEq)] -pub struct SparseSubCircleDomain { - pub domains: Vec, - pub large_domain_log_size: u32, -} - -impl SparseSubCircleDomain { - pub fn flatten(&self) -> Vec { - self.iter() - .flat_map(|sub_circle_domain| sub_circle_domain.to_decommitment_positions()) - .collect() - } -} - -impl Deref for SparseSubCircleDomain { - type Target = Vec; - - fn deref(&self) -> &Self::Target { - &self.domains - } -} - -/// Represents a circle domain relative to a larger circle domain. The `initial_index` is the bit -/// reversed query index in the larger domain. -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct SubCircleDomain { - pub coset_index: usize, - pub log_size: u32, -} - -impl SubCircleDomain { - /// Calculates the decommitment positions needed for each query given the fri step size. - pub fn to_decommitment_positions(&self) -> Vec { - (self.coset_index << self.log_size..(self.coset_index + 1) << self.log_size).collect() - } - - /// Returns the represented [CircleDomain]. - pub fn to_circle_domain(&self, query_domain: &CircleDomain) -> CircleDomain { - let query = bit_reverse_index(self.coset_index << self.log_size, query_domain.log_size()); - let initial_index = query_domain.index_at(query); - let half_coset = Coset::new(initial_index, self.log_size - 1); - CircleDomain::new(half_coset) - } -} - -#[cfg(test)] -mod tests { - use crate::core::channel::Blake2sChannel; - use crate::core::poly::circle::CanonicCoset; - use crate::core::queries::Queries; - use crate::core::utils::bit_reverse; - - #[test] - fn test_generate_queries() { - let channel = &mut Blake2sChannel::default(); - let log_query_size = 31; - let n_queries = 100; - - let queries = Queries::generate(channel, log_query_size, n_queries); - - assert!(queries.len() == n_queries); - for query in queries.iter() { - assert!(*query < 1 << log_query_size); - } - } - - #[test] - pub fn test_folded_queries() { - let log_domain_size = 7; - let domain = CanonicCoset::new(log_domain_size).circle_domain(); - let mut values = domain.iter().collect::>(); - bit_reverse(&mut values); - - let log_folded_domain_size = 5; - let folded_domain = CanonicCoset::new(log_folded_domain_size).circle_domain(); - let mut folded_values = folded_domain.iter().collect::>(); - bit_reverse(&mut folded_values); - - // Generate all possible queries. - let queries = Queries { - positions: (0..1 << log_domain_size).collect(), - log_domain_size, - }; - let n_folds = log_domain_size - log_folded_domain_size; - let ratio = 1 << n_folds; - - let folded_queries = queries.fold(n_folds); - let repeated_folded_queries = folded_queries - .iter() - .flat_map(|q| std::iter::repeat(q).take(ratio)); - for (query, folded_query) in queries.iter().zip(repeated_folded_queries) { - // Check only the x coordinate since folding might give you the conjugate point. - assert_eq!( - values[*query].repeated_double(n_folds).x, - folded_values[*folded_query].x - ); - } - } - - #[test] - pub fn test_conjugate_queries() { - let channel = &mut Blake2sChannel::default(); - let log_domain_size = 7; - let domain = CanonicCoset::new(log_domain_size).circle_domain(); - let mut values = domain.iter().collect::>(); - bit_reverse(&mut values); - - // Test random queries one by one because the conjugate queries are sorted. - for _ in 0..100 { - let query = Queries::generate(channel, log_domain_size, 1); - let conjugate_query = query[0] ^ 1; - let query_and_conjugate = query.opening_positions(1).flatten(); - let mut expected_query_and_conjugate = vec![query[0], conjugate_query]; - expected_query_and_conjugate.sort(); - assert_eq!(query_and_conjugate, expected_query_and_conjugate); - assert_eq!(values[query[0]], values[conjugate_query].conjugate()); - } - } - - #[test] - pub fn test_decommitment_positions() { - let channel = &mut Blake2sChannel::default(); - let log_domain_size = 31; - let n_queries = 100; - let fri_step_size = 3; - - let queries = Queries::generate(channel, log_domain_size, n_queries); - let queries_with_added_positions = queries.opening_positions(fri_step_size).flatten(); - - assert!(queries_with_added_positions.is_sorted()); - assert_eq!( - queries_with_added_positions.len(), - n_queries * (1 << fri_step_size) - ); - } - - #[test] - pub fn test_dedup_decommitment_positions() { - let log_domain_size = 7; - - // Generate all possible queries. - let queries = Queries { - positions: (0..1 << log_domain_size).collect(), - log_domain_size, - }; - let queries_with_conjugates = queries.opening_positions(log_domain_size - 2).flatten(); - - assert_eq!(*queries, *queries_with_conjugates); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/test_utils.rs b/Stwo_wrapper/crates/prover/src/core/test_utils.rs deleted file mode 100644 index 5ebaeaf..0000000 --- a/Stwo_wrapper/crates/prover/src/core/test_utils.rs +++ /dev/null @@ -1,17 +0,0 @@ -use super::backend::cpu::CpuCircleEvaluation; -use super::channel::Blake2sChannel; -use super::fields::m31::BaseField; -use super::fields::qm31::SecureField; - -pub fn secure_eval_to_base_eval( - eval: &CpuCircleEvaluation, -) -> CpuCircleEvaluation { - CpuCircleEvaluation::new( - eval.domain, - eval.values.iter().map(|x| x.to_m31_array()[0]).collect(), - ) -} - -pub fn test_channel() -> Blake2sChannel { - Blake2sChannel::default() -} diff --git a/Stwo_wrapper/crates/prover/src/core/utils.rs b/Stwo_wrapper/crates/prover/src/core/utils.rs deleted file mode 100644 index 334edb7..0000000 --- a/Stwo_wrapper/crates/prover/src/core/utils.rs +++ /dev/null @@ -1,327 +0,0 @@ -use std::iter::Peekable; -use std::ops::{Add, Mul, Sub}; - -use num_traits::{One, Zero}; - -use super::circle::CirclePoint; -use super::constraints::point_vanishing; -use super::fields::m31::BaseField; -use super::fields::qm31::SecureField; -use super::fields::{Field, FieldExpOps}; -use super::poly::circle::CircleDomain; - -pub trait IteratorMutExt<'a, T: 'a>: Iterator { - fn assign(self, other: impl IntoIterator) - where - Self: Sized, - { - self.zip(other).for_each(|(a, b)| *a = b); - } -} - -impl<'a, T: 'a, I: Iterator> IteratorMutExt<'a, T> for I {} - -/// An iterator that takes elements from the underlying [Peekable] while the predicate is true. -/// Used to implement [PeekableExt::peek_take_while]. -pub struct PeekTakeWhile<'a, I: Iterator, P: FnMut(&I::Item) -> bool> { - iter: &'a mut Peekable, - predicate: P, -} -impl<'a, I: Iterator, P: FnMut(&I::Item) -> bool> Iterator for PeekTakeWhile<'a, I, P> { - type Item = I::Item; - - fn next(&mut self) -> Option { - self.iter.next_if(&mut self.predicate) - } -} -pub trait PeekableExt<'a, I: Iterator> { - /// Returns an iterator that takes elements from the underlying [Peekable] while the predicate - /// is true. - /// Unlike [Iterator::take_while], this iterator does not consume the first element that does - /// not satisfy the predicate. - fn peek_take_while bool>( - &'a mut self, - predicate: P, - ) -> PeekTakeWhile<'a, I, P>; -} -impl<'a, I: Iterator> PeekableExt<'a, I> for Peekable { - fn peek_take_while bool>( - &'a mut self, - predicate: P, - ) -> PeekTakeWhile<'a, I, P> { - PeekTakeWhile { - iter: self, - predicate, - } - } -} - -/// Returns the bit reversed index of `i` which is represented by `log_size` bits. -pub fn bit_reverse_index(i: usize, log_size: u32) -> usize { - if log_size == 0 { - return i; - } - i.reverse_bits() >> (usize::BITS - log_size) -} - -/// Returns the index of the previous element in a bit reversed -/// [super::poly::circle::CircleEvaluation] of log size `eval_log_size` relative to a smaller domain -/// of size `domain_log_size`. -pub fn previous_bit_reversed_circle_domain_index( - i: usize, - domain_log_size: u32, - eval_log_size: u32, -) -> usize { - offset_bit_reversed_circle_domain_index(i, domain_log_size, eval_log_size, -1) -} - -/// Returns the index of the offset element in a bit reversed -/// [super::poly::circle::CircleEvaluation] of log size `eval_log_size` relative to a smaller domain -/// of size `domain_log_size`. -pub fn offset_bit_reversed_circle_domain_index( - i: usize, - domain_log_size: u32, - eval_log_size: u32, - offset: isize, -) -> usize { - let mut prev_index = bit_reverse_index(i, eval_log_size); - let half_size = 1 << (eval_log_size - 1); - let step_size = offset * (1 << (eval_log_size - domain_log_size - 1)) as isize; - if prev_index < half_size { - prev_index = (prev_index as isize + step_size).rem_euclid(half_size as isize) as usize; - } else { - prev_index = - ((prev_index as isize - step_size).rem_euclid(half_size as isize) as usize) + half_size; - } - bit_reverse_index(prev_index, eval_log_size) -} - -// TODO(AlonH): Pair both functions below with bit reverse. Consider removing both and calculating -// the indices instead. -pub(crate) fn circle_domain_order_to_coset_order(values: &[BaseField]) -> Vec { - let n = values.len(); - let mut coset_order = vec![]; - for i in 0..(n / 2) { - coset_order.push(values[i]); - coset_order.push(values[n - 1 - i]); - } - coset_order -} - -pub(crate) fn coset_order_to_circle_domain_order(values: &[F]) -> Vec { - let mut circle_domain_order = Vec::with_capacity(values.len()); - let n = values.len(); - let half_len = n / 2; - for i in 0..half_len { - circle_domain_order.push(values[i << 1]); - } - for i in 0..half_len { - circle_domain_order.push(values[n - 1 - (i << 1)]); - } - circle_domain_order -} - -/// Converts an index within a [`Coset`] to the corresponding index in a [`CircleDomain`]. -/// -/// [`CircleDomain`]: crate::core::poly::circle::CircleDomain -/// [`Coset`]: crate::core::circle::Coset -pub fn coset_index_to_circle_domain_index(coset_index: usize, log_domain_size: u32) -> usize { - if coset_index % 2 == 0 { - coset_index / 2 - } else { - ((2 << log_domain_size) - coset_index) / 2 - } -} - -/// Performs a naive bit-reversal permutation inplace. -/// -/// # Panics -/// -/// Panics if the length of the slice is not a power of two. -// TODO: Implement cache friendly implementation. -// TODO(spapini): Move this to the cpu backend. -pub fn bit_reverse(v: &mut [T]) { - let n = v.len(); - assert!(n.is_power_of_two()); - let log_n = n.ilog2(); - for i in 0..n { - let j = bit_reverse_index(i, log_n); - if j > i { - v.swap(i, j); - } - } -} - -pub fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec { - (0..n_powers) - .scan(SecureField::one(), |acc, _| { - let res = *acc; - *acc *= felt; - Some(res) - }) - .collect() -} - -/// Securely combines the given values using the given random alpha and z. -/// Alpha and z should be secure field elements for soundness. -pub fn shifted_secure_combination(values: &[F], alpha: EF, z: EF) -> EF -where - EF: Copy + Zero + Mul + Add + Sub, -{ - let res = values - .iter() - .fold(EF::zero(), |acc, &value| acc * alpha + value); - res - z -} - -pub fn point_vanish_denominator_inverses( - domain: CircleDomain, - vanish_point: CirclePoint, -) -> Vec { - let mut denoms = vec![]; - for point in domain.iter() { - // TODO(AlonH): Use `point_vanishing_fraction` instead of `point_vanishing` everywhere. - denoms.push(point_vanishing(vanish_point, point)); - } - bit_reverse(&mut denoms); - let mut denom_inverses = vec![BaseField::zero(); 1 << (domain.log_size())]; - BaseField::batch_inverse(&denoms, &mut denom_inverses); - denom_inverses -} - -#[cfg(test)] -mod tests { - use itertools::Itertools; - use num_traits::One; - - use super::{ - offset_bit_reversed_circle_domain_index, previous_bit_reversed_circle_domain_index, - }; - use crate::core::backend::cpu::CpuCircleEvaluation; - use crate::core::fields::qm31::SecureField; - use crate::core::fields::FieldExpOps; - use crate::core::poly::circle::CanonicCoset; - use crate::core::poly::NaturalOrder; - use crate::core::utils::bit_reverse; - use crate::{m31, qm31}; - - #[test] - fn bit_reverse_works() { - let mut data = [0, 1, 2, 3, 4, 5, 6, 7]; - bit_reverse(&mut data); - assert_eq!(data, [0, 4, 2, 6, 1, 5, 3, 7]); - } - - #[test] - #[should_panic] - fn bit_reverse_non_power_of_two_size_fails() { - let mut data = [0, 1, 2, 3, 4, 5]; - bit_reverse(&mut data); - } - - #[test] - fn generate_secure_powers_works() { - let felt = qm31!(1, 2, 3, 4); - let n_powers = 10; - - let powers = super::generate_secure_powers(felt, n_powers); - - assert_eq!(powers.len(), n_powers); - assert_eq!(powers[0], SecureField::one()); - assert_eq!(powers[1], felt); - assert_eq!(powers[7], felt.pow(7)); - } - - #[test] - fn generate_empty_secure_powers_works() { - let felt = qm31!(1, 2, 3, 4); - let max_log_size = 0; - - let powers = super::generate_secure_powers(felt, max_log_size); - - assert_eq!(powers, vec![]); - } - - #[test] - fn test_offset_bit_reversed_circle_domain_index() { - let domain_log_size = 3; - let eval_log_size = 6; - let initial_index = 5; - - let actual = offset_bit_reversed_circle_domain_index( - initial_index, - domain_log_size, - eval_log_size, - -2, - ); - let expected_prev = previous_bit_reversed_circle_domain_index( - initial_index, - domain_log_size, - eval_log_size, - ); - let expected_prev2 = previous_bit_reversed_circle_domain_index( - expected_prev, - domain_log_size, - eval_log_size, - ); - assert_eq!(actual, expected_prev2); - } - - #[test] - fn test_previous_bit_reversed_circle_domain_index() { - let log_size = 4; - let n = 1 << log_size; - let domain = CanonicCoset::new(log_size).circle_domain(); - let values = (0..n).map(|i| m31!(i as u32)).collect_vec(); - let evaluation = CpuCircleEvaluation::<_, NaturalOrder>::new(domain, values.clone()); - let bit_reversed_evaluation = evaluation.clone().bit_reverse(); - - // 2 · 14 - // · | · - // 13 | 1 - // · | · - // 3 | 15 - // · | · - // 12 | 0 - // ·--------------|---------------· - // 4 | 8 - // · | · - // 11 | 7 - // · | · - // 5 | 9 - // · | · - // 10 · 6 - let neighbor_pairs = (0..n) - .map(|index| { - let prev_index = - previous_bit_reversed_circle_domain_index(index, log_size - 3, log_size); - ( - bit_reversed_evaluation[index], - bit_reversed_evaluation[prev_index], - ) - }) - .sorted() - .collect_vec(); - let mut expected_neighbor_pairs = vec![ - (m31!(0), m31!(4)), - (m31!(15), m31!(11)), - (m31!(1), m31!(5)), - (m31!(14), m31!(10)), - (m31!(2), m31!(6)), - (m31!(13), m31!(9)), - (m31!(3), m31!(7)), - (m31!(12), m31!(8)), - (m31!(4), m31!(0)), - (m31!(11), m31!(15)), - (m31!(5), m31!(1)), - (m31!(10), m31!(14)), - (m31!(6), m31!(2)), - (m31!(9), m31!(13)), - (m31!(7), m31!(3)), - (m31!(8), m31!(12)), - ]; - expected_neighbor_pairs.sort(); - - assert_eq!(neighbor_pairs, expected_neighbor_pairs); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/blake2_hash.rs b/Stwo_wrapper/crates/prover/src/core/vcs/blake2_hash.rs deleted file mode 100644 index b702fcd..0000000 --- a/Stwo_wrapper/crates/prover/src/core/vcs/blake2_hash.rs +++ /dev/null @@ -1,139 +0,0 @@ -use std::fmt; - -use blake2::{Blake2s256, Digest}; -use bytemuck::{Pod, Zeroable}; -use serde::{Deserialize, Serialize}; - -// Wrapper for the blake2s hash type. -#[repr(C, align(32))] -#[derive(Clone, Copy, PartialEq, Default, Eq, Pod, Zeroable, Deserialize, Serialize)] -pub struct Blake2sHash(pub [u8; 32]); - -impl From for Vec { - fn from(value: Blake2sHash) -> Self { - Vec::from(value.0) - } -} - -impl From> for Blake2sHash { - fn from(value: Vec) -> Self { - Self( - value - .try_into() - .expect("Failed converting Vec to Blake2Hash type"), - ) - } -} - -impl From<&[u8]> for Blake2sHash { - fn from(value: &[u8]) -> Self { - Self( - value - .try_into() - .expect("Failed converting &[u8] to Blake2sHash Type!"), - ) - } -} - -impl AsRef<[u8]> for Blake2sHash { - fn as_ref(&self) -> &[u8] { - &self.0 - } -} - -impl From for [u8; 32] { - fn from(val: Blake2sHash) -> Self { - val.0 - } -} - -impl fmt::Display for Blake2sHash { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(&hex::encode(self.0)) - } -} - -impl fmt::Debug for Blake2sHash { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - ::fmt(self, f) - } -} - -impl super::hash::Hash for Blake2sHash { - fn to_bytes(&self) -> Vec { - self.0.to_vec() - } - - fn from_bytes(bytes: &[u8]) -> Self { - bytes.into() - } -} - -// Wrapper for the blake2s Hashing functionalities. -#[derive(Clone, Debug, Default)] -pub struct Blake2sHasher { - state: Blake2s256, -} - -impl Blake2sHasher { - pub fn new() -> Self { - Self { - state: Blake2s256::new(), - } - } - - pub fn update(&mut self, data: &[u8]) { - blake2::Digest::update(&mut self.state, data); - } - - pub fn finalize(self) -> Blake2sHash { - Blake2sHash(self.state.finalize().into()) - } - - pub fn concat_and_hash(v1: &Blake2sHash, v2: &Blake2sHash) -> Blake2sHash { - let mut hasher = Self::new(); - hasher.update(v1.as_ref()); - hasher.update(v2.as_ref()); - hasher.finalize() - } - - pub fn hash(data: &[u8]) -> Blake2sHash { - let mut hasher = Self::new(); - hasher.update(data); - hasher.finalize() - } -} - -#[cfg(test)] -mod tests { - use blake2::Digest; - - use super::{Blake2sHash, Blake2sHasher}; - - impl Blake2sHasher { - fn finalize_reset(&mut self) -> Blake2sHash { - Blake2sHash(self.state.finalize_reset().into()) - } - } - - #[test] - fn single_hash_test() { - let hash_a = Blake2sHasher::hash(b"a"); - assert_eq!( - hash_a.to_string(), - "4a0d129873403037c2cd9b9048203687f6233fb6738956e0349bd4320fec3e90" - ); - } - - #[test] - fn hash_state_test() { - let mut state = Blake2sHasher::new(); - state.update(b"a"); - state.update(b"b"); - let hash = state.finalize_reset(); - let hash_empty = state.finalize(); - - assert_eq!(hash.to_string(), Blake2sHasher::hash(b"ab").to_string()); - assert_eq!(hash_empty.to_string(), Blake2sHasher::hash(b"").to_string()); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/blake2_merkle.rs b/Stwo_wrapper/crates/prover/src/core/vcs/blake2_merkle.rs deleted file mode 100644 index 293ed4a..0000000 --- a/Stwo_wrapper/crates/prover/src/core/vcs/blake2_merkle.rs +++ /dev/null @@ -1,148 +0,0 @@ -use num_traits::Zero; -use serde::{Deserialize, Serialize}; - -use super::blake2_hash::Blake2sHash; -use super::blake2s_ref::compress; -use super::ops::MerkleHasher; -use crate::core::channel::{Blake2sChannel, MerkleChannel}; -use crate::core::fields::m31::BaseField; - -#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Deserialize, Serialize)] -pub struct Blake2sMerkleHasher; -impl MerkleHasher for Blake2sMerkleHasher { - type Hash = Blake2sHash; - - fn hash_node( - children_hashes: Option<(Self::Hash, Self::Hash)>, - column_values: &[BaseField], - ) -> Self::Hash { - let mut state = [0; 8]; - if let Some((left, right)) = children_hashes { - state = compress( - state, - unsafe { std::mem::transmute([left, right]) }, - 0, - 0, - 0, - 0, - ); - } - let rem = 15 - ((column_values.len() + 15) % 16); - let padded_values = column_values - .iter() - .copied() - .chain(std::iter::repeat(BaseField::zero()).take(rem)); - for chunk in padded_values.array_chunks::<16>() { - state = compress(state, unsafe { std::mem::transmute(chunk) }, 0, 0, 0, 0); - } - state.map(|x| x.to_le_bytes()).flatten().into() - } -} - -#[derive(Default)] -pub struct Blake2sMerkleChannel; - -impl MerkleChannel for Blake2sMerkleChannel { - type C = Blake2sChannel; - type H = Blake2sMerkleHasher; - - fn mix_root(channel: &mut Self::C, root: ::Hash) { - channel.update_digest(super::blake2_hash::Blake2sHasher::concat_and_hash( - &channel.digest(), - &root, - )); - } -} - -#[cfg(test)] -mod tests { - use num_traits::Zero; - - use super::Blake2sMerkleChannel; - use crate::core::channel::{Blake2sChannel, MerkleChannel}; - use crate::core::fields::m31::BaseField; - use crate::core::vcs::blake2_merkle::{Blake2sHash, Blake2sMerkleHasher}; - use crate::core::vcs::test_utils::prepare_merkle; - use crate::core::vcs::verifier::MerkleVerificationError; - - #[test] - fn test_merkle_success() { - let (queries, decommitment, values, verifier) = prepare_merkle::(); - - verifier.verify(queries, values, decommitment).unwrap(); - } - - #[test] - fn test_merkle_invalid_witness() { - let (queries, mut decommitment, values, verifier) = prepare_merkle::(); - decommitment.hash_witness[4] = Blake2sHash::default(); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::RootMismatch - ); - } - - #[test] - fn test_merkle_invalid_value() { - let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3][2] = BaseField::zero(); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::RootMismatch - ); - } - - #[test] - fn test_merkle_witness_too_short() { - let (queries, mut decommitment, values, verifier) = prepare_merkle::(); - decommitment.hash_witness.pop(); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::WitnessTooShort - ); - } - - #[test] - fn test_merkle_witness_too_long() { - let (queries, mut decommitment, values, verifier) = prepare_merkle::(); - decommitment.hash_witness.push(Blake2sHash::default()); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::WitnessTooLong - ); - } - - #[test] - fn test_merkle_column_values_too_long() { - let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3].push(BaseField::zero()); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooLong - ); - } - - #[test] - fn test_merkle_column_values_too_short() { - let (queries, decommitment, mut values, verifier) = prepare_merkle::(); - values[3].pop(); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooShort - ); - } - - #[test] - fn test_merkle_channel() { - let mut channel = Blake2sChannel::default(); - let (_queries, _decommitment, _values, verifier) = prepare_merkle::(); - Blake2sMerkleChannel::mix_root(&mut channel, verifier.root); - assert_eq!(channel.channel_time.n_challenges, 1); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/blake2s_ref.rs b/Stwo_wrapper/crates/prover/src/core/vcs/blake2s_ref.rs deleted file mode 100644 index ab32ea6..0000000 --- a/Stwo_wrapper/crates/prover/src/core/vcs/blake2s_ref.rs +++ /dev/null @@ -1,217 +0,0 @@ -//! An AVX512 implementation of the BLAKE2s compression function. -//! Based on . - -pub const IV: [u32; 8] = [ - 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, -]; - -pub const SIGMA: [[u8; 16]; 10] = [ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], - [14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3], - [11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4], - [7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8], - [9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13], - [2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9], - [12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11], - [13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10], - [6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5], - [10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0], -]; - -#[inline(always)] -fn add(a: u32, b: u32) -> u32 { - a.wrapping_add(b) -} - -#[inline(always)] -fn xor(a: u32, b: u32) -> u32 { - a ^ b -} - -#[inline(always)] -fn rot16(x: u32) -> u32 { - (x >> 16) | (x << (32 - 16)) -} - -#[inline(always)] -fn rot12(x: u32) -> u32 { - (x >> 12) | (x << (32 - 12)) -} - -#[inline(always)] -fn rot8(x: u32) -> u32 { - (x >> 8) | (x << (32 - 8)) -} - -#[inline(always)] -fn rot7(x: u32) -> u32 { - (x >> 7) | (x << (32 - 7)) -} - -#[inline(always)] -fn round(v: &mut [u32; 16], m: [u32; 16], r: usize) { - v[0] = add(v[0], m[SIGMA[r][0] as usize]); - v[1] = add(v[1], m[SIGMA[r][2] as usize]); - v[2] = add(v[2], m[SIGMA[r][4] as usize]); - v[3] = add(v[3], m[SIGMA[r][6] as usize]); - v[0] = add(v[0], v[4]); - v[1] = add(v[1], v[5]); - v[2] = add(v[2], v[6]); - v[3] = add(v[3], v[7]); - v[12] = xor(v[12], v[0]); - v[13] = xor(v[13], v[1]); - v[14] = xor(v[14], v[2]); - v[15] = xor(v[15], v[3]); - v[12] = rot16(v[12]); - v[13] = rot16(v[13]); - v[14] = rot16(v[14]); - v[15] = rot16(v[15]); - v[8] = add(v[8], v[12]); - v[9] = add(v[9], v[13]); - v[10] = add(v[10], v[14]); - v[11] = add(v[11], v[15]); - v[4] = xor(v[4], v[8]); - v[5] = xor(v[5], v[9]); - v[6] = xor(v[6], v[10]); - v[7] = xor(v[7], v[11]); - v[4] = rot12(v[4]); - v[5] = rot12(v[5]); - v[6] = rot12(v[6]); - v[7] = rot12(v[7]); - v[0] = add(v[0], m[SIGMA[r][1] as usize]); - v[1] = add(v[1], m[SIGMA[r][3] as usize]); - v[2] = add(v[2], m[SIGMA[r][5] as usize]); - v[3] = add(v[3], m[SIGMA[r][7] as usize]); - v[0] = add(v[0], v[4]); - v[1] = add(v[1], v[5]); - v[2] = add(v[2], v[6]); - v[3] = add(v[3], v[7]); - v[12] = xor(v[12], v[0]); - v[13] = xor(v[13], v[1]); - v[14] = xor(v[14], v[2]); - v[15] = xor(v[15], v[3]); - v[12] = rot8(v[12]); - v[13] = rot8(v[13]); - v[14] = rot8(v[14]); - v[15] = rot8(v[15]); - v[8] = add(v[8], v[12]); - v[9] = add(v[9], v[13]); - v[10] = add(v[10], v[14]); - v[11] = add(v[11], v[15]); - v[4] = xor(v[4], v[8]); - v[5] = xor(v[5], v[9]); - v[6] = xor(v[6], v[10]); - v[7] = xor(v[7], v[11]); - v[4] = rot7(v[4]); - v[5] = rot7(v[5]); - v[6] = rot7(v[6]); - v[7] = rot7(v[7]); - - v[0] = add(v[0], m[SIGMA[r][8] as usize]); - v[1] = add(v[1], m[SIGMA[r][10] as usize]); - v[2] = add(v[2], m[SIGMA[r][12] as usize]); - v[3] = add(v[3], m[SIGMA[r][14] as usize]); - v[0] = add(v[0], v[5]); - v[1] = add(v[1], v[6]); - v[2] = add(v[2], v[7]); - v[3] = add(v[3], v[4]); - v[15] = xor(v[15], v[0]); - v[12] = xor(v[12], v[1]); - v[13] = xor(v[13], v[2]); - v[14] = xor(v[14], v[3]); - v[15] = rot16(v[15]); - v[12] = rot16(v[12]); - v[13] = rot16(v[13]); - v[14] = rot16(v[14]); - v[10] = add(v[10], v[15]); - v[11] = add(v[11], v[12]); - v[8] = add(v[8], v[13]); - v[9] = add(v[9], v[14]); - v[5] = xor(v[5], v[10]); - v[6] = xor(v[6], v[11]); - v[7] = xor(v[7], v[8]); - v[4] = xor(v[4], v[9]); - v[5] = rot12(v[5]); - v[6] = rot12(v[6]); - v[7] = rot12(v[7]); - v[4] = rot12(v[4]); - v[0] = add(v[0], m[SIGMA[r][9] as usize]); - v[1] = add(v[1], m[SIGMA[r][11] as usize]); - v[2] = add(v[2], m[SIGMA[r][13] as usize]); - v[3] = add(v[3], m[SIGMA[r][15] as usize]); - v[0] = add(v[0], v[5]); - v[1] = add(v[1], v[6]); - v[2] = add(v[2], v[7]); - v[3] = add(v[3], v[4]); - v[15] = xor(v[15], v[0]); - v[12] = xor(v[12], v[1]); - v[13] = xor(v[13], v[2]); - v[14] = xor(v[14], v[3]); - v[15] = rot8(v[15]); - v[12] = rot8(v[12]); - v[13] = rot8(v[13]); - v[14] = rot8(v[14]); - v[10] = add(v[10], v[15]); - v[11] = add(v[11], v[12]); - v[8] = add(v[8], v[13]); - v[9] = add(v[9], v[14]); - v[5] = xor(v[5], v[10]); - v[6] = xor(v[6], v[11]); - v[7] = xor(v[7], v[8]); - v[4] = xor(v[4], v[9]); - v[5] = rot7(v[5]); - v[6] = rot7(v[6]); - v[7] = rot7(v[7]); - v[4] = rot7(v[4]); -} - -/// Performs a Blake2s compression. -pub fn compress( - h_vecs: [u32; 8], - msg_vecs: [u32; 16], - count_low: u32, - count_high: u32, - lastblock: u32, - lastnode: u32, -) -> [u32; 8] { - let mut v = [ - h_vecs[0], - h_vecs[1], - h_vecs[2], - h_vecs[3], - h_vecs[4], - h_vecs[5], - h_vecs[6], - h_vecs[7], - IV[0], - IV[1], - IV[2], - IV[3], - xor(IV[4], count_low), - xor(IV[5], count_high), - xor(IV[6], lastblock), - xor(IV[7], lastnode), - ]; - - round(&mut v, msg_vecs, 0); - round(&mut v, msg_vecs, 1); - round(&mut v, msg_vecs, 2); - round(&mut v, msg_vecs, 3); - round(&mut v, msg_vecs, 4); - round(&mut v, msg_vecs, 5); - round(&mut v, msg_vecs, 6); - round(&mut v, msg_vecs, 7); - round(&mut v, msg_vecs, 8); - round(&mut v, msg_vecs, 9); - - [ - xor(xor(h_vecs[0], v[0]), v[8]), - xor(xor(h_vecs[1], v[1]), v[9]), - xor(xor(h_vecs[2], v[2]), v[10]), - xor(xor(h_vecs[3], v[3]), v[11]), - xor(xor(h_vecs[4], v[4]), v[12]), - xor(xor(h_vecs[5], v[5]), v[13]), - xor(xor(h_vecs[6], v[6]), v[14]), - xor(xor(h_vecs[7], v[7]), v[15]), - ] -} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/blake3_hash.rs b/Stwo_wrapper/crates/prover/src/core/vcs/blake3_hash.rs deleted file mode 100644 index e9b9d0b..0000000 --- a/Stwo_wrapper/crates/prover/src/core/vcs/blake3_hash.rs +++ /dev/null @@ -1,132 +0,0 @@ -use std::fmt; - -use serde::{Deserialize, Serialize}; - -use crate::core::vcs::hash::Hash; - -// Wrapper for the blake3 hash type. -#[derive(Clone, Copy, PartialEq, Default, Eq, Serialize, Deserialize)] -pub struct Blake3Hash([u8; 32]); - -impl From for Vec { - fn from(value: Blake3Hash) -> Self { - Vec::from(value.0) - } -} - -impl From> for Blake3Hash { - fn from(value: Vec) -> Self { - Self( - value - .try_into() - .expect("Failed converting Vec to Blake3Hash Type!"), - ) - } -} - -impl From<&[u8]> for Blake3Hash { - fn from(value: &[u8]) -> Self { - Self( - value - .try_into() - .expect("Failed converting &[u8] to Blake3Hash Type!"), - ) - } -} - -impl AsRef<[u8]> for Blake3Hash { - fn as_ref(&self) -> &[u8] { - &self.0 - } -} - -impl fmt::Display for Blake3Hash { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(&hex::encode(self.0)) - } -} - -impl fmt::Debug for Blake3Hash { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - ::fmt(self, f) - } -} - -impl Hash for Blake3Hash { - fn to_bytes(&self) -> Vec { - self.0.to_vec() - } - - fn from_bytes(bytes: &[u8]) -> Self { - bytes.into() - } -} - -// Wrapper for the blake3 Hashing functionalities. -#[derive(Clone, Default)] -pub struct Blake3Hasher { - state: blake3::Hasher, -} - -impl Blake3Hasher { - pub fn new() -> Self { - Self { - state: blake3::Hasher::new(), - } - } - pub fn update(&mut self, data: &[u8]) { - self.state.update(data); - } - - pub fn finalize(self) -> Blake3Hash { - Blake3Hash(self.state.finalize().into()) - } - - pub fn concat_and_hash(v1: &Blake3Hash, v2: &Blake3Hash) -> Blake3Hash { - let mut hasher = Self::new(); - hasher.update(v1.as_ref()); - hasher.update(v2.as_ref()); - hasher.finalize() - } - - pub fn hash(data: &[u8]) -> Blake3Hash { - let mut hasher = Self::new(); - hasher.update(data); - hasher.finalize() - } -} - -#[cfg(test)] -impl Blake3Hasher { - fn finalize_reset(&mut self) -> Blake3Hash { - let res = Blake3Hash(self.state.finalize().into()); - self.state.reset(); - res - } -} - -#[cfg(test)] -mod tests { - use crate::core::vcs::blake3_hash::Blake3Hasher; - - #[test] - fn single_hash_test() { - let hash_a = Blake3Hasher::hash(b"a"); - assert_eq!( - hash_a.to_string(), - "17762fddd969a453925d65717ac3eea21320b66b54342fde15128d6caf21215f" - ); - } - - #[test] - fn hash_state_test() { - let mut state = Blake3Hasher::new(); - state.update(b"a"); - state.update(b"b"); - let hash = state.finalize_reset(); - let hash_empty = state.finalize(); - - assert_eq!(hash.to_string(), Blake3Hasher::hash(b"ab").to_string()); - assert_eq!(hash_empty.to_string(), Blake3Hasher::hash(b"").to_string()) - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/hash.rs b/Stwo_wrapper/crates/prover/src/core/vcs/hash.rs deleted file mode 100644 index 066a5d1..0000000 --- a/Stwo_wrapper/crates/prover/src/core/vcs/hash.rs +++ /dev/null @@ -1,15 +0,0 @@ -use std::fmt::{Debug, Display}; - -pub trait Hash: - Copy - + Default - + Display - + Debug - + Eq - + Send - + Sync - + 'static -{ - fn to_bytes(&self) -> Vec; - fn from_bytes(bytes: &[u8]) -> Self; -} \ No newline at end of file diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/mod.rs b/Stwo_wrapper/crates/prover/src/core/vcs/mod.rs deleted file mode 100644 index 7e19129..0000000 --- a/Stwo_wrapper/crates/prover/src/core/vcs/mod.rs +++ /dev/null @@ -1,20 +0,0 @@ -//! Vector commitment scheme (VCS) module. - -pub mod blake2_hash; -pub mod blake2_merkle; -pub mod blake2s_ref; -pub mod blake3_hash; -pub mod hash; -pub mod ops; -#[cfg(not(target_arch = "wasm32"))] -pub mod poseidon252_merkle; - -#[cfg(not(target_arch = "wasm32"))] -pub mod poseidon_bls_merkle; -pub mod prover; -mod utils; -pub mod verifier; - -#[cfg(test)] -mod test_utils; - diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/ops.rs b/Stwo_wrapper/crates/prover/src/core/vcs/ops.rs deleted file mode 100644 index 14093e5..0000000 --- a/Stwo_wrapper/crates/prover/src/core/vcs/ops.rs +++ /dev/null @@ -1,47 +0,0 @@ -use std::fmt::Debug; - -use serde::{Deserialize, Serialize}; - -use crate::core::backend::{Col, ColumnOps}; -use crate::core::fields::m31::BaseField; -use crate::core::vcs::hash::Hash; - -/// A Merkle node hash is a hash of: -/// [left_child_hash, right_child_hash], column0_value, column1_value, ... -/// "[]" denotes optional values. -/// The largest Merkle layer has no left and right child hashes. The rest of the layers have -/// children hashes. -/// At each layer, the tree may have multiple columns of the same length as the layer. -/// Each node in that layer contains one value from each column. -pub trait MerkleHasher: Debug + Default + Clone { - type Hash: Hash; - /// Hashes a single Merkle node. See [MerkleHasher] for more details. - fn hash_node( - children_hashes: Option<(Self::Hash, Self::Hash)>, - column_values: &[BaseField], - ) -> Self::Hash; -} - -/// Trait for performing Merkle operations on a commitment scheme. -pub trait MerkleOps: - ColumnOps + ColumnOps + for<'de> Deserialize<'de> + Serialize -{ - /// Commits on an entire layer of the Merkle tree. - /// See [MerkleHasher] for more details. - /// - /// The layer has 2^`log_size` nodes that need to be hashed. The topmost layer has 1 node, - /// which is a hash of 2 children and some columns. - /// - /// `prev_layer` is the previous layer of the Merkle tree, if this is not the leaf layer. - /// That layer is assumed to have 2^(`log_size`+1) nodes. - /// - /// `columns` are the extra columns that need to be hashed in each node. - /// They are assumed to be of size 2^`log_size`. - /// - /// Returns the next Merkle layer hashes. - fn commit_on_layer( - log_size: u32, - prev_layer: Option<&Col>, - columns: &[&Col], - ) -> Col; -} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/poseidon252_merkle.rs b/Stwo_wrapper/crates/prover/src/core/vcs/poseidon252_merkle.rs deleted file mode 100644 index 6441b71..0000000 --- a/Stwo_wrapper/crates/prover/src/core/vcs/poseidon252_merkle.rs +++ /dev/null @@ -1,182 +0,0 @@ -use num_traits::Zero; -use serde::{Deserialize, Serialize}; -use starknet_crypto::{poseidon_hash, poseidon_hash_many}; -use starknet_ff::FieldElement as FieldElement252; - -use super::ops::MerkleHasher; -use crate::core::channel::{MerkleChannel, Poseidon252Channel}; -use crate::core::fields::m31::BaseField; -use crate::core::vcs::hash::Hash; - -const ELEMENTS_IN_BLOCK: usize = 8; - -#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Deserialize, Serialize)] -pub struct Poseidon252MerkleHasher; -impl MerkleHasher for Poseidon252MerkleHasher { - type Hash = FieldElement252; - - fn hash_node( - children_hashes: Option<(Self::Hash, Self::Hash)>, - column_values: &[BaseField], - ) -> Self::Hash { - let n_column_blocks = column_values.len().div_ceil(ELEMENTS_IN_BLOCK); - let values_len = 2 + n_column_blocks; - let mut values = Vec::with_capacity(values_len); - - if let Some((left, right)) = children_hashes { - values.push(left); - values.push(right); - } - - let padding_length = ELEMENTS_IN_BLOCK * n_column_blocks - column_values.len(); - let padded_values = column_values - .iter() - .copied() - .chain(std::iter::repeat(BaseField::zero()).take(padding_length)); - for chunk in padded_values.array_chunks::() { - let mut word = FieldElement252::default(); - for x in chunk { - word = word * FieldElement252::from(2u64.pow(31)) + FieldElement252::from(x.0); - } - values.push(word); - } - poseidon_hash_many(&values) - } -} - -impl Hash for FieldElement252 { - fn to_bytes(&self) -> Vec { - self.to_bytes_be().to_vec() - } - - fn from_bytes(bytes: &[u8]) -> Self { - let mut bytes_array = [0u8; 32]; - bytes_array.copy_from_slice(bytes); - FieldElement252::from_bytes_be(&bytes_array).unwrap() - } -} - -#[derive(Default)] -pub struct Poseidon252MerkleChannel; - -impl MerkleChannel for Poseidon252MerkleChannel { - type C = Poseidon252Channel; - type H = Poseidon252MerkleHasher; - - fn mix_root(channel: &mut Self::C, root: ::Hash) { - channel.update_digest(poseidon_hash(channel.digest(), root)); - } -} - -#[cfg(test)] -mod tests { - use num_traits::Zero; - use starknet_ff::FieldElement as FieldElement252; - - use crate::core::fields::m31::BaseField; - use crate::core::vcs::ops::MerkleHasher; - use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleHasher; - use crate::core::vcs::test_utils::prepare_merkle; - use crate::core::vcs::verifier::MerkleVerificationError; - use crate::m31; - - #[test] - fn test_vector() { - assert_eq!( - Poseidon252MerkleHasher::hash_node(None, &[m31!(0), m31!(1)]), - FieldElement252::from_dec_str( - "2552053700073128806553921687214114320458351061521275103654266875084493044716" - ) - .unwrap() - ); - - assert_eq!( - Poseidon252MerkleHasher::hash_node( - Some((FieldElement252::from(1u32), FieldElement252::from(2u32))), - &[m31!(3)] - ), - FieldElement252::from_dec_str( - "159358216886023795422515519110998391754567506678525778721401012606792642769" - ) - .unwrap() - ); - } - - #[test] - fn test_merkle_success() { - let (queries, decommitment, values, verifier) = prepare_merkle::(); - verifier.verify(queries, values, decommitment).unwrap(); - } - - #[test] - fn test_merkle_invalid_witness() { - let (queries, mut decommitment, values, verifier) = - prepare_merkle::(); - decommitment.hash_witness[4] = FieldElement252::default(); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::RootMismatch - ); - } - - #[test] - fn test_merkle_invalid_value() { - let (queries, decommitment, mut values, verifier) = - prepare_merkle::(); - values[3][2] = BaseField::zero(); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::RootMismatch - ); - } - - #[test] - fn test_merkle_witness_too_short() { - let (queries, mut decommitment, values, verifier) = - prepare_merkle::(); - decommitment.hash_witness.pop(); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::WitnessTooShort - ); - } - - #[test] - fn test_merkle_witness_too_long() { - let (queries, mut decommitment, values, verifier) = - prepare_merkle::(); - decommitment.hash_witness.push(FieldElement252::default()); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::WitnessTooLong - ); - } - - #[test] - fn test_merkle_column_values_too_long() { - let (queries, decommitment, mut values, verifier) = - prepare_merkle::(); - values[3].push(BaseField::zero()); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooLong - ); - } - - #[test] - fn test_merkle_column_values_too_short() { - let (queries, decommitment, mut values, verifier) = - prepare_merkle::(); - values[3].pop(); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooShort - ); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/poseidon_bls_merkle.rs b/Stwo_wrapper/crates/prover/src/core/vcs/poseidon_bls_merkle.rs deleted file mode 100644 index 196e734..0000000 --- a/Stwo_wrapper/crates/prover/src/core/vcs/poseidon_bls_merkle.rs +++ /dev/null @@ -1,581 +0,0 @@ -use num_traits::Zero; -use serde::{Deserialize, Serialize}; -use ark_bls12_381::Fr as BlsFr; -use ark_ff::{BigInteger, Field, PrimeField}; - - -use super::ops::MerkleHasher; -use crate::core::channel::{MerkleChannel, PoseidonBLSChannel}; -use crate::core::fields::m31::BaseField; -use crate::core::vcs::hash::Hash; - -const ELEMENTS_IN_BLOCK: usize = 8; - -//Optimize constant to be real constants (no conversion) and merge duplicated code in VCS poseidono -fn poseidon_comp_consts(idx: usize) -> BlsFr { - match idx { - 0 => BlsFr::from_be_bytes_mod_order(&[ - 111, 0, 122, 85, 17, 86, 179, 164, 73, 228, 73, 54, 183, 192, 147, 100, 74, 14, 211, - 63, 51, 234, 204, 198, 40, 233, 66, 232, 54, 193, 168, 117, - ]), - 1 => BlsFr::from_be_bytes_mod_order(&[ - 54, 13, 116, 112, 97, 30, 71, 61, 53, 63, 98, 143, 118, 209, 16, 243, 78, 113, 22, 47, - 49, 0, 59, 112, 87, 83, 140, 37, 150, 66, 99, 3, - ]), - 2 => BlsFr::from_be_bytes_mod_order(&[ - 75, 95, 236, 58, 160, 115, 223, 68, 1, 144, 145, 240, 7, 164, 76, 169, 150, 72, 73, - 101, 247, 3, 109, 206, 62, 157, 9, 119, 237, 205, 192, 246, - ]), - 3 => BlsFr::from_be_bytes_mod_order(&[ - 103, 207, 24, 104, 175, 99, 150, 192, 184, 76, 206, 113, 94, 83, 159, 132, 158, 6, 205, - 28, 56, 58, 197, 176, 97, 0, 199, 107, 204, 151, 58, 17, - ]), - 4 => BlsFr::from_be_bytes_mod_order(&[ - 85, 93, 180, 209, 220, 237, 129, 159, 93, 61, 231, 15, 222, 131, 241, 199, 211, 232, - 201, 137, 104, 229, 22, 162, 58, 119, 26, 92, 156, 130, 87, 170, - ]), - 5 => BlsFr::from_be_bytes_mod_order(&[ - 43, 171, 148, 215, 174, 34, 45, 19, 93, 195, 198, 197, 254, 191, 170, 49, 73, 8, 172, - 47, 18, 235, 224, 111, 189, 183, 66, 19, 191, 99, 24, 139, - ]), - 6 => BlsFr::from_be_bytes_mod_order(&[ - 102, 244, 75, 229, 41, 102, 130, 196, 250, 120, 130, 121, 157, 109, 208, 73, 182, 215, - 210, 201, 80, 204, 249, 140, 242, 229, 13, 109, 30, 187, 119, 194, - ]), - 7 => BlsFr::from_be_bytes_mod_order(&[ - 21, 12, 147, 254, 246, 82, 251, 28, 43, 240, 62, 26, 41, 170, 135, 31, 239, 119, 231, - 215, 54, 118, 108, 93, 9, 57, 217, 39, 83, 204, 93, 200, - ]), - 8 => BlsFr::from_be_bytes_mod_order(&[ - 50, 112, 102, 30, 104, 146, 139, 58, 149, 93, 85, 219, 86, 220, 87, 193, 3, 204, 10, - 96, 20, 30, 137, 78, 20, 37, 157, 206, 83, 119, 130, 178, - ]), - 9 => BlsFr::from_be_bytes_mod_order(&[ - 7, 63, 17, 111, 4, 18, 46, 37, 160, 183, 175, 228, 226, 5, 114, 153, 180, 7, 195, 112, - 242, 181, 161, 204, 206, 159, 185, 255, 195, 69, 175, 179, - ]), - 10 => BlsFr::from_be_bytes_mod_order(&[ - 64, 159, 218, 34, 85, 140, 254, 77, 61, 216, 220, 226, 79, 105, 231, 111, 140, 42, 174, - 177, 221, 15, 9, 214, 94, 101, 76, 113, 243, 42, 162, 63, - ]), - 11 => BlsFr::from_be_bytes_mod_order(&[ - 42, 50, 236, 92, 78, 229, 177, 131, 122, 255, 208, 156, 31, 83, 245, 253, 85, 201, 205, - 32, 97, 174, 147, 202, 142, 186, 215, 111, 199, 21, 84, 216, - ]), - 12 => BlsFr::from_be_bytes_mod_order(&[ - 88, 72, 235, 235, 89, 35, 233, 37, 85, 183, 18, 79, 255, 186, 93, 107, 213, 113, 198, - 249, 132, 25, 94, 185, 207, 211, 163, 232, 235, 85, 177, 212, - ]), - 13 => BlsFr::from_be_bytes_mod_order(&[ - 39, 3, 38, 238, 3, 157, 241, 158, 101, 30, 44, 252, 116, 6, 40, 202, 99, 77, 36, 252, - 110, 37, 89, 242, 45, 140, 203, 226, 146, 239, 238, 173, - ]), - 14 => BlsFr::from_be_bytes_mod_order(&[ - 39, 198, 100, 42, 198, 51, 188, 102, 220, 16, 15, 231, 252, 250, 84, 145, 138, 248, - 149, 188, 224, 18, 241, 130, 160, 104, 252, 55, 193, 130, 226, 116, - ]), - 15 => BlsFr::from_be_bytes_mod_order(&[ - 27, 223, 216, 176, 20, 1, 199, 10, 210, 127, 87, 57, 105, 137, 18, 157, 113, 14, 31, - 182, 171, 151, 106, 69, 156, 161, 134, 130, 226, 109, 127, 249, - ]), - 16 => BlsFr::from_be_bytes_mod_order(&[ - 73, 27, 155, 166, 152, 59, 207, 159, 5, 254, 71, 148, 173, 180, 74, 48, 135, 155, 248, - 40, 150, 98, 225, 245, 125, 144, 246, 114, 65, 78, 138, 74, - ]), - 17 => BlsFr::from_be_bytes_mod_order(&[ - 22, 42, 20, 198, 47, 154, 137, 184, 20, 185, 214, 169, 200, 77, 214, 120, 244, 246, - 251, 63, 144, 84, 211, 115, 200, 50, 216, 36, 38, 26, 53, 234, - ]), - 18 => BlsFr::from_be_bytes_mod_order(&[ - 45, 25, 62, 15, 118, 222, 88, 107, 42, 246, 247, 158, 49, 39, 254, 234, 172, 10, 31, - 199, 30, 44, 240, 192, 247, 152, 36, 102, 123, 91, 107, 236, - ]), - 19 => BlsFr::from_be_bytes_mod_order(&[ - 70, 239, 216, 169, 162, 98, 214, 216, 253, 201, 202, 92, 4, 176, 152, 47, 36, 221, 204, - 110, 152, 99, 136, 90, 106, 115, 42, 57, 6, 160, 123, 149, - ]), - 20 => BlsFr::from_be_bytes_mod_order(&[ - 80, 151, 23, 224, 194, 0, 227, 201, 45, 141, 202, 41, 115, 179, 219, 69, 240, 120, 130, - 148, 53, 26, 208, 122, 231, 92, 187, 120, 6, 147, 167, 152, - ]), - 21 => BlsFr::from_be_bytes_mod_order(&[ - 114, 153, 178, 132, 100, 168, 201, 79, 185, 212, 223, 97, 56, 15, 57, 192, 220, 169, - 194, 192, 20, 17, 135, 137, 226, 39, 37, 40, 32, 240, 27, 252, - ]), - 22 => BlsFr::from_be_bytes_mod_order(&[ - 4, 76, 163, 204, 74, 133, 215, 59, 129, 105, 110, 241, 16, 78, 103, 79, 79, 239, 248, - 41, 132, 153, 15, 248, 93, 11, 245, 141, 200, 164, 170, 148, - ]), - 23 => BlsFr::from_be_bytes_mod_order(&[ - 28, 186, 242, 179, 113, 218, 198, 168, 29, 4, 83, 65, 109, 62, 35, 92, 184, 217, 226, - 212, 243, 20, 244, 111, 97, 152, 120, 95, 12, 214, 185, 175, - ]), - 24 => BlsFr::from_be_bytes_mod_order(&[ - 29, 91, 39, 119, 105, 44, 32, 91, 14, 108, 73, 208, 97, 182, 181, 244, 41, 60, 74, 176, - 56, 253, 187, 220, 52, 62, 7, 97, 15, 63, 237, 229, - ]), - 25 => BlsFr::from_be_bytes_mod_order(&[ - 86, 174, 124, 122, 82, 147, 189, 194, 62, 133, 225, 105, 140, 129, 199, 127, 138, 216, - 140, 75, 51, 165, 120, 4, 55, 173, 4, 124, 110, 219, 89, 186, - ]), - 26 => BlsFr::from_be_bytes_mod_order(&[ - 46, 155, 219, 186, 61, 211, 75, 255, 170, 48, 83, 91, 221, 116, 154, 126, 6, 169, 173, - 176, 193, 230, 249, 98, 246, 14, 151, 27, 141, 115, 176, 79, - ]), - 27 => BlsFr::from_be_bytes_mod_order(&[ - 45, 225, 24, 134, 177, 128, 17, 202, 139, 213, 186, 227, 105, 105, 41, 159, 222, 64, - 251, 226, 109, 4, 123, 5, 3, 90, 19, 102, 31, 34, 65, 139, - ]), - 28 => BlsFr::from_be_bytes_mod_order(&[ - 46, 7, 222, 23, 128, 184, 167, 13, 13, 91, 74, 63, 24, 65, 220, 216, 42, 185, 57, 92, - 68, 155, 233, 71, 188, 153, 136, 132, 186, 150, 167, 33, - ]), - 29 => BlsFr::from_be_bytes_mod_order(&[ - 15, 105, 241, 133, 77, 32, 202, 12, 187, 219, 99, 219, 213, 45, 173, 22, 37, 4, 64, - 169, 157, 107, 138, 243, 130, 94, 76, 43, 183, 73, 37, 202, - ]), - 30 => BlsFr::from_be_bytes_mod_order(&[ - 93, 201, 135, 49, 142, 110, 89, 193, 175, 184, 123, 101, 93, 213, 140, 193, 210, 46, - 81, 58, 5, 131, 140, 212, 88, 93, 4, 177, 53, 185, 87, 202, - ]), - 31 => BlsFr::from_be_bytes_mod_order(&[ - 72, 183, 37, 117, 133, 113, 201, 223, 108, 1, 220, 99, 154, 133, 240, 114, 151, 105, - 107, 27, 182, 120, 99, 58, 41, 220, 145, 222, 149, 239, 83, 246, - ]), - 32 => BlsFr::from_be_bytes_mod_order(&[ - 94, 86, 94, 8, 192, 130, 16, 153, 37, 107, 86, 73, 14, 174, 225, 213, 115, 175, 209, - 11, 182, 209, 125, 19, 202, 78, 92, 97, 27, 42, 55, 24, - ]), - 33 => BlsFr::from_be_bytes_mod_order(&[ - 46, 177, 178, 84, 23, 254, 23, 103, 13, 19, 93, 198, 57, 251, 9, 164, 108, 229, 17, 53, - 7, 249, 109, 233, 129, 108, 5, 148, 34, 220, 112, 94, - ]), - 34 => BlsFr::from_be_bytes_mod_order(&[ - 17, 92, 208, 160, 100, 60, 251, 152, 140, 36, 203, 68, 195, 250, 180, 138, 255, 54, - 198, 97, 210, 108, 196, 45, 184, 177, 189, 244, 149, 59, 216, 44, - ]), - 35 => BlsFr::from_be_bytes_mod_order(&[ - 38, 202, 41, 63, 123, 44, 70, 45, 6, 109, 115, 120, 185, 153, 134, 139, 187, 87, 221, - 241, 78, 15, 149, 138, 222, 128, 22, 18, 49, 29, 4, 205, - ]), - 36 => BlsFr::from_be_bytes_mod_order(&[ - 65, 71, 64, 13, 142, 26, 172, 207, 49, 26, 107, 91, 118, 32, 17, 171, 62, 69, 50, 110, - 77, 75, 157, 226, 105, 146, 129, 107, 153, 197, 40, 172, - ]), - 37 => BlsFr::from_be_bytes_mod_order(&[ - 107, 13, 183, 220, 204, 75, 161, 178, 104, 246, 189, 204, 77, 55, 40, 72, 212, 167, 41, - 118, 194, 104, 234, 48, 81, 154, 47, 115, 230, 219, 77, 85, - ]), - 38 => BlsFr::from_be_bytes_mod_order(&[ - 23, 191, 27, 147, 196, 199, 224, 26, 42, 131, 10, 161, 98, 65, 44, 217, 15, 22, 11, - 249, 247, 30, 150, 127, 245, 32, 157, 20, 178, 72, 32, 202, - ]), - 39 => BlsFr::from_be_bytes_mod_order(&[ - 75, 67, 28, 217, 239, 237, 188, 148, 207, 30, 202, 111, 158, 156, 24, 57, 208, 230, - 106, 139, 255, 168, 200, 70, 76, 172, 129, 163, 157, 60, 248, 241, - ]), - 40 => BlsFr::from_be_bytes_mod_order(&[ - 53, 180, 26, 122, 196, 243, 197, 113, 162, 79, 132, 86, 54, 156, 133, 223, 224, 60, 3, - 84, 189, 140, 253, 56, 5, 200, 111, 46, 125, 194, 147, 197, - ]), - 41 => BlsFr::from_be_bytes_mod_order(&[ - 59, 20, 128, 8, 5, 35, 196, 57, 67, 89, 39, 153, 72, 73, 190, 169, 100, 225, 77, 59, - 235, 45, 221, 222, 114, 172, 21, 106, 244, 53, 208, 158, - ]), - 42 => BlsFr::from_be_bytes_mod_order(&[ - 44, 198, 129, 0, 49, 220, 27, 13, 73, 80, 133, 109, 201, 7, 213, 117, 8, 226, 134, 68, - 42, 45, 62, 178, 39, 22, 24, 216, 116, 177, 76, 109, - ]), - 43 => BlsFr::from_be_bytes_mod_order(&[ - 111, 65, 65, 200, 64, 28, 90, 57, 91, 166, 121, 14, 253, 113, 199, 12, 4, 175, 234, 6, - 195, 201, 40, 38, 188, 171, 221, 92, 181, 71, 125, 81, - ]), - 44 => BlsFr::from_be_bytes_mod_order(&[ - 37, 189, 187, 237, 161, 189, 232, 193, 5, 150, 24, 226, 175, 210, 239, 153, 158, 81, - 122, 169, 59, 120, 52, 29, 145, 243, 24, 192, 159, 12, 181, 102, - ]), - 45 => BlsFr::from_be_bytes_mod_order(&[ - 57, 42, 74, 135, 88, 224, 110, 232, 185, 95, 51, 194, 93, 222, 138, 192, 42, 94, 208, - 162, 123, 97, 146, 108, 198, 49, 52, 135, 7, 63, 127, 123, - ]), - 46 => BlsFr::from_be_bytes_mod_order(&[ - 39, 42, 85, 135, 138, 8, 68, 43, 154, 166, 17, 31, 77, 224, 9, 72, 94, 106, 111, 209, - 93, 184, 147, 101, 231, 187, 206, 240, 46, 181, 134, 108, - ]), - 47 => BlsFr::from_be_bytes_mod_order(&[ - 99, 30, 193, 214, 210, 141, 217, 232, 36, 238, 137, 163, 7, 48, 174, 247, 171, 70, 58, - 207, 201, 209, 132, 179, 85, 170, 5, 253, 105, 56, 234, 181, - ]), - 48 => BlsFr::from_be_bytes_mod_order(&[ - 78, 182, 253, 161, 15, 208, 251, 222, 2, 199, 68, 155, 251, 221, 195, 91, 205, 130, 37, - 231, 229, 195, 131, 58, 8, 24, 161, 0, 64, 157, 198, 242, - ]), - 49 => BlsFr::from_be_bytes_mod_order(&[ - 45, 91, 48, 139, 12, 240, 44, 223, 239, 161, 60, 78, 96, 226, 98, 57, 166, 235, 186, 1, - 22, 148, 221, 18, 155, 146, 91, 60, 91, 33, 224, 226, - ]), - 50 => BlsFr::from_be_bytes_mod_order(&[ - 22, 84, 159, 198, 175, 47, 59, 114, 221, 93, 41, 61, 114, 226, 229, 242, 68, 223, 244, - 47, 24, 180, 108, 86, 239, 56, 197, 124, 49, 22, 115, 172, - ]), - 51 => BlsFr::from_be_bytes_mod_order(&[ - 66, 51, 38, 119, 255, 53, 156, 94, 141, 184, 54, 217, 245, 251, 84, 130, 46, 57, 189, - 94, 34, 52, 11, 185, 186, 151, 91, 161, 169, 43, 227, 130, - ]), - 52 => BlsFr::from_be_bytes_mod_order(&[ - 73, 215, 210, 192, 180, 73, 229, 23, 155, 197, 204, 195, 180, 76, 96, 117, 217, 132, - 155, 86, 16, 70, 95, 9, 234, 114, 93, 220, 151, 114, 58, 148, - ]), - 53 => BlsFr::from_be_bytes_mod_order(&[ - 100, 194, 15, 185, 13, 122, 0, 56, 49, 117, 124, 196, 198, 34, 111, 110, 73, 133, 252, - 158, 203, 65, 107, 159, 104, 76, 160, 53, 29, 150, 121, 4, - ]), - 54 => BlsFr::from_be_bytes_mod_order(&[ - 89, 207, 244, 13, 232, 59, 82, 180, 27, 196, 67, 215, 151, 149, 16, 215, 113, 201, 64, - 185, 117, 140, 168, 32, 254, 115, 181, 200, 213, 88, 9, 52, - ]), - 55 => BlsFr::from_be_bytes_mod_order(&[ - 83, 219, 39, 49, 115, 12, 57, 176, 78, 221, 135, 95, 227, 183, 200, 130, 128, 130, 133, - 205, 188, 98, 29, 122, 244, 248, 13, 213, 62, 187, 113, 176, - ]), - 56 => BlsFr::from_be_bytes_mod_order(&[ - 27, 16, 187, 122, 130, 175, 206, 57, 250, 105, 195, 162, 173, 82, 247, 109, 118, 57, - 130, 101, 52, 66, 3, 17, 155, 113, 38, 217, 180, 104, 96, 223, - ]), - 57 => BlsFr::from_be_bytes_mod_order(&[ - 86, 27, 96, 18, 214, 102, 191, 225, 121, 196, 221, 127, 132, 205, 209, 83, 21, 150, - 211, 170, 199, 197, 112, 12, 235, 49, 159, 145, 4, 106, 99, 201, - ]), - 58 => BlsFr::from_be_bytes_mod_order(&[ - 15, 30, 117, 5, 235, 217, 29, 47, 199, 156, 45, 247, 220, 152, 163, 190, 209, 179, 105, - 104, 186, 4, 5, 192, 144, 210, 127, 106, 0, 183, 223, 200, - ]), - 59 => BlsFr::from_be_bytes_mod_order(&[ - 47, 49, 63, 175, 13, 63, 97, 135, 83, 122, 116, 151, 163, 180, 63, 70, 121, 127, 214, - 227, 241, 142, 177, 202, 255, 69, 119, 86, 184, 25, 187, 32, - ]), - 60 => BlsFr::from_be_bytes_mod_order(&[ - 58, 92, 187, 109, 228, 80, 180, 129, 250, 60, 166, 28, 14, 209, 91, 197, 92, 173, 17, - 235, 240, 247, 206, 184, 240, 188, 62, 115, 46, 203, 38, 246, - ]), - 61 => BlsFr::from_be_bytes_mod_order(&[ - 104, 29, 147, 65, 27, 248, 206, 99, 246, 113, 106, 239, 189, 14, 36, 80, 100, 84, 192, - 52, 142, 227, 143, 171, 235, 38, 71, 2, 113, 76, 207, 148, - ]), - 62 => BlsFr::from_be_bytes_mod_order(&[ - 81, 120, 233, 64, 245, 0, 4, 49, 38, 70, 180, 54, 114, 127, 14, 128, 167, 184, 242, - 233, 238, 31, 220, 103, 124, 72, 49, 167, 103, 39, 119, 251, - ]), - 63 => BlsFr::from_be_bytes_mod_order(&[ - 61, 171, 84, 188, 155, 239, 104, 141, 217, 32, 134, 226, 83, 180, 57, 214, 81, 186, - 166, 226, 15, 137, 43, 98, 134, 85, 39, 203, 202, 145, 89, 130, - ]), - 64 => BlsFr::from_be_bytes_mod_order(&[ - 75, 60, 231, 83, 17, 33, 143, 154, 233, 5, 248, 78, 170, 91, 43, 56, 24, 68, 139, 191, - 57, 114, 225, 170, 214, 157, 227, 33, 0, 144, 21, 208, - ]), - 65 => BlsFr::from_be_bytes_mod_order(&[ - 6, 219, 251, 66, 185, 121, 136, 77, 226, 128, 211, 22, 112, 18, 63, 116, 76, 36, 179, - 59, 65, 15, 239, 212, 54, 128, 69, 172, 242, 183, 26, 227, - ]), - 66 => BlsFr::from_be_bytes_mod_order(&[ - 6, 141, 107, 70, 8, 170, 232, 16, 198, 240, 57, 234, 25, 115, 166, 62, 184, 210, 222, - 114, 227, 210, 201, 236, 167, 252, 50, 210, 47, 24, 185, 211, - ]), - 67 => BlsFr::from_be_bytes_mod_order(&[ - 76, 92, 37, 69, 137, 169, 42, 54, 8, 74, 87, 211, 177, 217, 100, 39, 138, 204, 126, 79, - 232, 246, 159, 41, 85, 149, 79, 39, 167, 156, 235, 239, - ]), - 68 => BlsFr::from_be_bytes_mod_order(&[ - 108, 186, 197, 225, 112, 9, 132, 235, 195, 45, 161, 91, 75, 185, 104, 63, 170, 186, - 181, 95, 103, 204, 196, 247, 29, 149, 96, 179, 71, 90, 119, 235, - ]), - 69 => BlsFr::from_be_bytes_mod_order(&[ - 70, 3, 196, 3, 187, 250, 154, 23, 115, 138, 92, 98, 120, 234, 171, 28, 55, 236, 48, - 176, 115, 122, 162, 64, 159, 196, 137, 128, 105, 235, 152, 60, - ]), - 70 => BlsFr::from_be_bytes_mod_order(&[ - 104, 148, 231, 226, 43, 44, 29, 92, 112, 167, 18, 166, 52, 90, 230, 177, 146, 169, 200, - 51, 169, 35, 76, 49, 197, 106, 172, 209, 107, 194, 241, 0, - ]), - 71 => BlsFr::from_be_bytes_mod_order(&[ - 91, 226, 203, 188, 68, 5, 58, 208, 138, 250, 77, 30, 171, 199, 243, 210, 49, 238, 167, - 153, 185, 63, 34, 110, 144, 91, 125, 77, 101, 197, 142, 187, - ]), - 72 => BlsFr::from_be_bytes_mod_order(&[ - 88, 229, 95, 40, 123, 69, 58, 152, 8, 98, 74, 140, 42, 53, 61, 82, 141, 160, 247, 231, - 19, 165, 198, 208, 215, 113, 30, 71, 6, 63, 166, 17, - ]), - 73 => BlsFr::from_be_bytes_mod_order(&[ - 54, 110, 191, 175, 163, 173, 56, 28, 14, 226, 88, 201, 184, 253, 252, 205, 184, 104, - 167, 215, 225, 241, 246, 154, 43, 93, 252, 197, 87, 37, 85, 223, - ]), - 74 => BlsFr::from_be_bytes_mod_order(&[ - 69, 118, 106, 183, 40, 150, 140, 100, 47, 144, 217, 124, 207, 85, 4, 221, 193, 5, 24, - 168, 25, 235, 188, 196, 208, 156, 63, 93, 120, 77, 103, 206, - ]), - 75 => BlsFr::from_be_bytes_mod_order(&[ - 57, 103, 143, 101, 81, 47, 30, 228, 4, 219, 48, 36, 244, 29, 63, 86, 126, 246, 109, - 137, 208, 68, 208, 34, 230, 188, 34, 158, 149, 188, 118, 177, - ]), - 76 => BlsFr::from_be_bytes_mod_order(&[ - 70, 58, 237, 29, 47, 31, 149, 94, 48, 120, 190, 91, 247, 191, 196, 111, 192, 235, 140, - 81, 85, 25, 6, 168, 134, 143, 24, 255, 174, 48, 207, 79, - ]), - 77 => BlsFr::from_be_bytes_mod_order(&[ - 33, 102, 143, 1, 106, 128, 99, 192, 213, 139, 119, 80, 163, 188, 47, 225, 207, 130, - 194, 95, 153, 220, 1, 164, 229, 52, 200, 143, 229, 61, 133, 254, - ]), - 78 => BlsFr::from_be_bytes_mod_order(&[ - 57, 208, 9, 148, 168, 165, 4, 106, 27, 199, 73, 54, 62, 152, 167, 104, 227, 77, 234, - 86, 67, 159, 225, 149, 75, 239, 66, 155, 197, 51, 22, 8, - ]), - 79 => BlsFr::from_be_bytes_mod_order(&[ - 77, 127, 93, 205, 120, 236, 233, 169, 51, 152, 77, 227, 44, 11, 72, 250, 194, 187, 169, - 31, 38, 25, 150, 184, 233, 209, 2, 23, 115, 189, 7, 204, - ]), - _ => BlsFr::ZERO, - } -} - -pub fn poseidon_hash_bls(x: BlsFr, y: BlsFr) -> BlsFr { - let mut state = [x, y, BlsFr::ZERO]; - poseidon_permute_comp_bls(&mut state); - state[0] + x -} - -pub fn poseidon_permute_comp_bls(state: &mut [BlsFr; 3]) { - let mut idx = 0; - mix(state); - - // Full rounds - for _ in 0..4 { - round_comp(state, idx, true); - idx += 3; - } - - // Partial rounds - for _ in 0..56 { - round_comp(state, idx, false); - idx += 1; - } - - // Full rounds - for _ in 0..4 { - round_comp(state, idx, true); - idx += 3; - } -} - -#[inline] -fn round_comp(state: &mut [BlsFr; 3], idx: usize, full: bool) { - if full { - state[0] += poseidon_comp_consts(idx); - state[1] += poseidon_comp_consts(idx + 1); - state[2] += poseidon_comp_consts(idx + 2); - // Optimize multiplication - state[0] = state[0] * state[0] * state[0] * state[0] * state[0]; - state[1] = state[1] * state[1] * state[1] * state[1] * state[1]; - state[2] = state[2] * state[2] * state[2] * state[2] * state[2]; - } else { - state[0] += poseidon_comp_consts(idx); - state[2] = state[2] * state[2] * state[2] * state[2] * state[2]; - } - mix(state); -} - -#[inline(always)] -fn mix(state: &mut [BlsFr; 3]) { - state[0] = state[0] + state[1] + state[2]; - state[1] = state[0] + state[1]; - state[2] = state[0] + state[2]; -} - -pub fn poseidon_hash_many_bls(msgs: &[BlsFr]) -> BlsFr { - let mut state = [BlsFr::ZERO, BlsFr::ZERO, BlsFr::ZERO]; - let mut iter = msgs.chunks_exact(2); - - for msg in iter.by_ref() { - state[0] += msg[0]; - state[1] += msg[1]; - poseidon_permute_comp_bls(&mut state); - } - let r = iter.remainder(); - if r.len() == 1 { - state[0] += r[0]; - } - state[r.len()] += BlsFr::ONE; - poseidon_permute_comp_bls(&mut state); - - state[0] -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Deserialize, Serialize)] -pub struct PoseidonBLSMerkleHasher; -impl MerkleHasher for PoseidonBLSMerkleHasher { - type Hash = BlsFr; - - fn hash_node( - children_hashes: Option<(Self::Hash, Self::Hash)>, - column_values: &[BaseField], - ) -> Self::Hash { - let n_column_blocks = column_values.len().div_ceil(ELEMENTS_IN_BLOCK); - let values_len = 2 + n_column_blocks; - let mut values = Vec::with_capacity(values_len); - - if let Some((left, right)) = children_hashes { - values.push(left); - values.push(right); - } - - let padding_length = ELEMENTS_IN_BLOCK * n_column_blocks - column_values.len(); - let padded_values = column_values - .iter() - .copied() - .chain(std::iter::repeat(BaseField::zero()).take(padding_length)); - for chunk in padded_values.array_chunks::() { - let mut word = BlsFr::default(); - for x in chunk { - word = word * BlsFr::from(2u64.pow(31)) + BlsFr::from(x.0); - } - values.push(word); - } - poseidon_hash_many_bls(&values) - } -} - -impl Hash for BlsFr { - fn to_bytes(&self) -> Vec { - self.into_bigint().to_bytes_be() - } - - fn from_bytes(bytes: &[u8]) -> Self { - BlsFr::from_be_bytes_mod_order(bytes) - } -} - -#[derive(Default)] -pub struct PoseidonBLSMerkleChannel; - -impl MerkleChannel for PoseidonBLSMerkleChannel { - type C = PoseidonBLSChannel; - type H = PoseidonBLSMerkleHasher; - - fn mix_root(channel: &mut Self::C, root: ::Hash) { - channel.update_digest(poseidon_hash_bls(channel.digest(), root)); - } -} - -#[cfg(test)] -mod tests { - use num_traits::Zero; - use ark_bls12_381::Fr as BlsFr; - - use crate::core::fields::m31::BaseField; - //use crate::core::vcs::ops::MerkleHasher; - use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleHasher; - use crate::core::vcs::test_utils::prepare_merkle; - use crate::core::vcs::verifier::MerkleVerificationError; - //use crate::m31; - - // TODO: Redo test vectors for bls - /*#[test] - fn test_vector() { - assert_eq!( - PoseidonBLSMerkleHasher::hash_node(None, &[m31!(0), m31!(1)]), - BlsFr::from_str( - "2552053700073128806553921687214114320458351061521275103654266875084493044716" - ) - .unwrap() - ); - - assert_eq!( - Poseidon252MerkleHasher::hash_node( - Some((FieldElement252::from(1u32), FieldElement252::from(2u32))), - &[m31!(3)] - ), - FieldElement252::from_dec_str( - "159358216886023795422515519110998391754567506678525778721401012606792642769" - ) - .unwrap() - ); - }*/ - - #[test] - fn test_merkle_success() { - let (queries, decommitment, values, verifier) = prepare_merkle::(); - verifier.verify(queries, values, decommitment).unwrap(); - } - - #[test] - fn test_merkle_invalid_witness() { - let (queries, mut decommitment, values, verifier) = - prepare_merkle::(); - decommitment.hash_witness[4] = BlsFr::default(); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::RootMismatch - ); - } - - #[test] - fn test_merkle_invalid_value() { - let (queries, decommitment, mut values, verifier) = - prepare_merkle::(); - values[3][2] = BaseField::zero(); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::RootMismatch - ); - } - - #[test] - fn test_merkle_witness_too_short() { - let (queries, mut decommitment, values, verifier) = - prepare_merkle::(); - decommitment.hash_witness.pop(); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::WitnessTooShort - ); - } - - #[test] - fn test_merkle_witness_too_long() { - let (queries, mut decommitment, values, verifier) = - prepare_merkle::(); - decommitment.hash_witness.push(BlsFr::default()); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::WitnessTooLong - ); - } - - #[test] - fn test_merkle_column_values_too_long() { - let (queries, decommitment, mut values, verifier) = - prepare_merkle::(); - values[3].push(BaseField::zero()); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooLong - ); - } - - #[test] - fn test_merkle_column_values_too_short() { - let (queries, decommitment, mut values, verifier) = - prepare_merkle::(); - values[3].pop(); - - assert_eq!( - verifier.verify(queries, values, decommitment).unwrap_err(), - MerkleVerificationError::ColumnValuesTooShort - ); - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/prover.rs b/Stwo_wrapper/crates/prover/src/core/vcs/prover.rs deleted file mode 100644 index c2fd63d..0000000 --- a/Stwo_wrapper/crates/prover/src/core/vcs/prover.rs +++ /dev/null @@ -1,223 +0,0 @@ -use std::cmp::Reverse; -use std::collections::BTreeMap; - -use itertools::Itertools; - -use super::ops::{MerkleHasher, MerkleOps}; -use super::utils::{next_decommitment_node, option_flatten_peekable}; -use crate::core::backend::{Col, Column}; -use crate::core::fields::m31::BaseField; -use crate::core::utils::PeekableExt; -use crate::core::ColumnVec; - -pub struct MerkleProver, H: MerkleHasher> { - /// Layers of the Merkle tree. - /// The first layer is the root layer. - /// The last layer is the largest layer. - /// See [MerkleOps::commit_on_layer] for more details. - pub layers: Vec>, -} -/// The MerkleProver struct represents a prover for a Merkle commitment scheme. -/// It is generic over the types `B` and `H`, which represent the Merkle operations and Merkle -/// hasher respectively. -impl, H: MerkleHasher> MerkleProver { - /// Commits to columns. - /// Columns must be of power of 2 sizes. - /// - /// # Arguments - /// - /// * `columns` - A vector of references to columns. - /// - /// # Panics - /// - /// This function will panic if the columns are not sorted in descending order or if the columns - /// vector is empty. - /// - /// # Returns - /// - /// A new instance of `MerkleProver` with the committed layers. - pub fn commit(columns: Vec<&Col>) -> Self { - assert!(!columns.is_empty()); - - let columns = &mut columns - .into_iter() - .sorted_by_key(|c| Reverse(c.len())) - .peekable(); - let mut layers: Vec> = Vec::new(); - - let max_log_size = columns.peek().unwrap().len().ilog2(); - for log_size in (0..=max_log_size).rev() { - // Take columns of the current log_size. - let layer_columns = columns - .peek_take_while(|column| column.len().ilog2() == log_size) - .collect_vec(); - - layers.push(B::commit_on_layer(log_size, layers.last(), &layer_columns)); - } - layers.reverse(); - Self { layers } - } - - /// Decommits to columns on the given queries. - /// Queries are given as indices to the largest column. - /// - /// # Arguments - /// - /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that - /// log_size. - /// * `columns` - A vector of references to columns. - /// - /// # Returns - /// - /// A tuple containing: - /// * A vector of vectors of queried values for each column, in the order of the input columns. - /// * A `MerkleDecommitment` containing the hash and column witnesses. - pub fn decommit( - &self, - queries_per_log_size: BTreeMap>, - columns: Vec<&Col>, - ) -> (ColumnVec>, MerkleDecommitment) { - // Check that queries are sorted and deduped. - // TODO(andrew): Consider using a Queries struct to prevent this. - for queries in queries_per_log_size.values() { - assert!( - queries.windows(2).all(|w| w[0] < w[1]), - "Queries are not sorted." - ); - } - - // Prepare output buffers. - let mut queried_values_by_layer = vec![]; - let mut decommitment = MerkleDecommitment::empty(); - - // Sort columns by layer. - let mut columns_by_layer = columns - .iter() - .sorted_by_key(|c| Reverse(c.len())) - .peekable(); - - let mut last_layer_queries = vec![]; - for layer_log_size in (0..self.layers.len() as u32).rev() { - // Prepare write buffer for queried values to the current layer. - let mut layer_queried_values = vec![]; - - // Prepare write buffer for queries to the current layer. This will propagate to the - // next layer. - let mut layer_total_queries = vec![]; - - // Each layer node is a hash of column values as previous layer hashes. - // Prepare the relevant columns and previous layer hashes to read from. - let layer_columns = columns_by_layer - .peek_take_while(|column| column.len().ilog2() == layer_log_size) - .collect_vec(); - let previous_layer_hashes = self.layers.get(layer_log_size as usize + 1); - - // Queries to this layer come from queried node in the previous layer and queried - // columns in this one. - let mut prev_layer_queries = last_layer_queries.into_iter().peekable(); - let mut layer_column_queries = - option_flatten_peekable(queries_per_log_size.get(&layer_log_size)); - - // Merge previous layer queries and column queries. - while let Some(node_index) = - next_decommitment_node(&mut prev_layer_queries, &mut layer_column_queries) - { - if let Some(previous_layer_hashes) = previous_layer_hashes { - // If the left child was not computed, add it to the witness. - if prev_layer_queries.next_if_eq(&(2 * node_index)).is_none() { - decommitment - .hash_witness - .push(previous_layer_hashes.at(2 * node_index)); - } - - // If the right child was not computed, add it to the witness. - if prev_layer_queries - .next_if_eq(&(2 * node_index + 1)) - .is_none() - { - decommitment - .hash_witness - .push(previous_layer_hashes.at(2 * node_index + 1)); - } - } - - // If the column values were queried, return them. - let node_values = layer_columns.iter().map(|c| c.at(node_index)); - if layer_column_queries.next_if_eq(&node_index).is_some() { - layer_queried_values.push(node_values.collect_vec()); - } else { - // Otherwise, add them to the witness. - decommitment.column_witness.extend(node_values); - } - - layer_total_queries.push(node_index); - } - - queried_values_by_layer.push(layer_queried_values); - - // Propagate queries to the next layer. - last_layer_queries = layer_total_queries; - } - queried_values_by_layer.reverse(); - - // Rearrange returned queried values according to input, and not by layer. - let queried_values = Self::rearrange_queried_values(queried_values_by_layer, columns); - - (queried_values, decommitment) - } - - /// Given queried values by layer, rearranges in the order of input columns. - fn rearrange_queried_values( - queried_values_by_layer: Vec>>, - columns: Vec<&Col>, - ) -> Vec> { - // Turn each column queried values into an iterator. - let mut queried_values_by_layer = queried_values_by_layer - .into_iter() - .map(|layer_results| { - layer_results - .into_iter() - .map(|x| x.into_iter()) - .collect_vec() - }) - .collect_vec(); - - // For each input column, fetch the queried values from the corresponding layer. - let queried_values = columns - .iter() - .map(|column| { - queried_values_by_layer - .get_mut(column.len().ilog2() as usize) - .unwrap() - .iter_mut() - .map(|x| x.next().unwrap()) - .collect_vec() - }) - .collect_vec(); - queried_values - } - - pub fn root(&self) -> H::Hash { - self.layers.first().unwrap().at(0) - } -} - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd)] -pub struct MerkleDecommitment { - /// Hash values that the verifier needs but cannot deduce from previous computations, in the - /// order they are needed. - pub hash_witness: Vec, - /// Column values that the verifier needs but cannot deduce from previous computations, in the - /// order they are needed. - /// This complements the column values that were queried. These must be supplied directly to - /// the verifier. - pub column_witness: Vec, -} -impl MerkleDecommitment { - fn empty() -> Self { - Self { - hash_witness: Vec::new(), - column_witness: Vec::new(), - } - } -} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/test_utils.rs b/Stwo_wrapper/crates/prover/src/core/vcs/test_utils.rs deleted file mode 100644 index 8fc535b..0000000 --- a/Stwo_wrapper/crates/prover/src/core/vcs/test_utils.rs +++ /dev/null @@ -1,60 +0,0 @@ -use std::collections::BTreeMap; - -use itertools::Itertools; -use rand::rngs::SmallRng; -use rand::{Rng, SeedableRng}; - -use super::ops::{MerkleHasher, MerkleOps}; -use super::prover::MerkleDecommitment; -use super::verifier::MerkleVerifier; -use crate::core::backend::CpuBackend; -use crate::core::fields::m31::BaseField; -use crate::core::vcs::prover::MerkleProver; - -pub type TestData = ( - BTreeMap>, - MerkleDecommitment, - Vec>, - MerkleVerifier, -); - -pub fn prepare_merkle() -> TestData -where - CpuBackend: MerkleOps, -{ - const N_COLS: usize = 10; - const N_QUERIES: usize = 3; - let log_size_range = 3..5; - - let mut rng = SmallRng::seed_from_u64(0); - let log_sizes = (0..N_COLS) - .map(|_| rng.gen_range(log_size_range.clone())) - .collect_vec(); - let cols = log_sizes - .iter() - .map(|&log_size| { - (0..(1 << log_size)) - .map(|_| BaseField::from(rng.gen_range(0..(1 << 30)))) - .collect_vec() - }) - .collect_vec(); - let merkle = MerkleProver::::commit(cols.iter().collect_vec()); - - let mut queries = BTreeMap::>::new(); - for log_size in log_size_range.rev() { - let layer_queries = (0..N_QUERIES) - .map(|_| rng.gen_range(0..(1 << log_size))) - .sorted() - .dedup() - .collect_vec(); - queries.insert(log_size, layer_queries); - } - - let (values, decommitment) = merkle.decommit(queries.clone(), cols.iter().collect_vec()); - - let verifier = MerkleVerifier { - root: merkle.root(), - column_log_sizes: log_sizes, - }; - (queries, decommitment, values, verifier) -} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/utils.rs b/Stwo_wrapper/crates/prover/src/core/vcs/utils.rs deleted file mode 100644 index 2f89f40..0000000 --- a/Stwo_wrapper/crates/prover/src/core/vcs/utils.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::iter::Peekable; - -/// Fetches the next node that needs to be decommited in the current Merkle layer. -pub fn next_decommitment_node( - prev_queries: &mut Peekable>, - layer_queries: &mut Peekable>, -) -> Option { - prev_queries - .peek() - .map(|q| *q / 2) - .into_iter() - .chain(layer_queries.peek().into_iter().copied()) - .min() -} - -pub fn option_flatten_peekable<'a, I: IntoIterator>( - a: Option, -) -> Peekable as IntoIterator>::IntoIter>>> { - a.into_iter().flatten().copied().peekable() -} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/verifier.rs b/Stwo_wrapper/crates/prover/src/core/vcs/verifier.rs deleted file mode 100644 index 08428de..0000000 --- a/Stwo_wrapper/crates/prover/src/core/vcs/verifier.rs +++ /dev/null @@ -1,191 +0,0 @@ -use std::cmp::Reverse; -use std::collections::BTreeMap; -use itertools::Itertools; -use thiserror::Error; - -use super::ops::MerkleHasher; -use super::prover::MerkleDecommitment; -use super::utils::{next_decommitment_node, option_flatten_peekable}; -use crate::core::fields::m31::BaseField; -use crate::core::utils::PeekableExt; -use crate::core::ColumnVec; - -// TODO(spapini): This struct is not necessary. Make it a function on decommitment? -pub struct MerkleVerifier { - pub root: H::Hash, - pub column_log_sizes: Vec, -} -impl MerkleVerifier { - pub fn new(root: H::Hash, column_log_sizes: Vec) -> Self { - Self { - root, - column_log_sizes, - } - } - /// Verifies the decommitment of the columns. - /// - /// # Arguments - /// - /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that - /// log_size. - /// * `queried_values` - A vector of vectors of queried values. For each column, there is a - /// vector of queried values to that column. - /// * `decommitment` - The decommitment object containing the witness and column values. - /// - /// # Errors - /// - /// Returns an error if any of the following conditions are met: - /// - /// * The witness is too long (not fully consumed). - /// * The witness is too short (missing values). - /// * The column values are too long (not fully consumed). - /// * The column values are too short (missing values). - /// * The computed root does not match the expected root. - /// - /// # Panics - /// - /// This function will panic if the `values` vector is not sorted in descending order based on - /// the `log_size` of the columns. - /// - /// # Returns - /// - /// Returns `Ok(())` if the decommitment is successfully verified. - pub fn verify( - &self, - queries_per_log_size: BTreeMap>, - queried_values: ColumnVec>, - decommitment: MerkleDecommitment, - ) -> Result<(), MerkleVerificationError> { - let max_log_size = self.column_log_sizes.iter().max().copied().unwrap_or(0); - - // Prepare read buffers. - let mut queried_values_by_layer = self - .column_log_sizes - .iter() - .copied() - .zip( - queried_values - .into_iter() - .map(|column_values| column_values.into_iter()), - ) - .sorted_by_key(|(log_size, _)| Reverse(*log_size)) - .peekable(); - let mut hash_witness = decommitment.hash_witness.into_iter(); - let mut column_witness = decommitment.column_witness.into_iter(); - - let mut last_layer_hashes: Option> = None; - for layer_log_size in (0..=max_log_size).rev() { - // Prepare read buffer for queried values to the current layer. - let mut layer_queried_values = queried_values_by_layer - .peek_take_while(|(log_size, _)| *log_size == layer_log_size) - .collect_vec(); - let n_columns_in_layer = layer_queried_values.len(); - - // Prepare write buffer for queries to the current layer. This will propagate to the - // next layer. - let mut layer_total_queries = vec![]; - - // Queries to this layer come from queried node in the previous layer and queried - // columns in this one. - let mut prev_layer_queries = last_layer_hashes - .iter() - .flatten() - .map(|(q, _)| *q) - .collect_vec() - .into_iter() - .peekable(); - let mut prev_layer_hashes = last_layer_hashes.as_ref().map(|x| x.iter().peekable()); - let mut layer_column_queries = - option_flatten_peekable(queries_per_log_size.get(&layer_log_size)); - - // Merge previous layer queries and column queries. - while let Some(node_index) = - next_decommitment_node(&mut prev_layer_queries, &mut layer_column_queries) - { - prev_layer_queries - .peek_take_while(|q| q / 2 == node_index) - .for_each(drop); - - let node_hashes = prev_layer_hashes - .as_mut() - .map(|prev_layer_hashes| { - { - // If the left child was not computed, read it from the witness. - let left_hash = prev_layer_hashes - .next_if(|(index, _)| *index == 2 * node_index) - .map(|(_, hash)| Ok(*hash)) - .unwrap_or_else(|| { - hash_witness - .next() - .ok_or(MerkleVerificationError::WitnessTooShort) - })?; - - // If the right child was not computed, read it to from the witness. - let right_hash = prev_layer_hashes - .next_if(|(index, _)| *index == 2 * node_index + 1) - .map(|(_, hash)| Ok(*hash)) - .unwrap_or_else(|| { - hash_witness - .next() - .ok_or(MerkleVerificationError::WitnessTooShort) - })?; - Ok((left_hash, right_hash)) - } - }) - .transpose()?; - - // If the column values were queried, read them from `queried_value`. - let node_values = if layer_column_queries.next_if_eq(&node_index).is_some() { - layer_queried_values - .iter_mut() - .map(|(_, ref mut column_queries)| { - column_queries - .next() - .ok_or(MerkleVerificationError::ColumnValuesTooShort) - }) - .collect::, _>>()? - } else { - // Otherwise, read them from the witness. - (&mut column_witness).take(n_columns_in_layer).collect_vec() - }; - if node_values.len() != n_columns_in_layer { - return Err(MerkleVerificationError::WitnessTooShort); - } - layer_total_queries.push((node_index, H::hash_node(node_hashes, &node_values))); - } - if !layer_queried_values.iter().all(|(_, c)| c.is_empty()) { - return Err(MerkleVerificationError::ColumnValuesTooLong); - } - last_layer_hashes = Some(layer_total_queries); - } - - // Check that all witnesses and values have been consumed. - if !hash_witness.is_empty() { - return Err(MerkleVerificationError::WitnessTooLong); - } - if !column_witness.is_empty() { - return Err(MerkleVerificationError::WitnessTooLong); - } - - let [(_, computed_root)] = last_layer_hashes.unwrap().try_into().unwrap(); - if computed_root != self.root { - return Err(MerkleVerificationError::RootMismatch); - } - - Ok(()) - } -} - -#[derive(Clone, Copy, Debug, Error, PartialEq, Eq)] -pub enum MerkleVerificationError { - #[error("Witness is too short.")] - WitnessTooShort, - #[error("Witness is too long.")] - WitnessTooLong, - #[error("Column values are too long.")] - ColumnValuesTooLong, - #[error("Column values are too short.")] - ColumnValuesTooShort, - #[error("Root mismatch.")] - RootMismatch, -} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/air.rs b/Stwo_wrapper/crates/prover/src/examples/blake/air.rs deleted file mode 100644 index e655ee6..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/blake/air.rs +++ /dev/null @@ -1,483 +0,0 @@ -use std::simd::u32x16; - -use itertools::{chain, multiunzip, Itertools}; -use num_traits::Zero; -use serde::Serialize; -use tracing::{span, Level}; - -use super::round::{blake_round_info, BlakeRoundComponent, BlakeRoundEval}; -use super::scheduler::{BlakeSchedulerComponent, BlakeSchedulerEval}; -use super::xor_table::{XorTableComponent, XorTableEval}; -use crate::constraint_framework::TraceLocationAllocator; -use crate::core::air::{Component, ComponentProver}; -use crate::core::backend::simd::m31::LOG_N_LANES; -use crate::core::backend::simd::SimdBackend; -use crate::core::backend::BackendForChannel; -use crate::core::channel::{Channel, MerkleChannel}; -use crate::core::fields::qm31::SecureField; -use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec}; -use crate::core::poly::circle::{CanonicCoset, PolyOps}; -use crate::core::prover::{prove, verify, StarkProof, VerificationError}; -use crate::core::vcs::ops::MerkleHasher; -use crate::examples::blake::round::RoundElements; -use crate::examples::blake::scheduler::{self, blake_scheduler_info, BlakeElements, BlakeInput}; -use crate::examples::blake::{ - round, xor_table, BlakeXorElements, XorAccums, N_ROUNDS, ROUND_LOG_SPLIT, -}; - -#[derive(Serialize)] -pub struct BlakeStatement0 { - log_size: u32, -} -impl BlakeStatement0 { - fn log_sizes(&self) -> TreeVec> { - let mut sizes = vec![]; - sizes.push( - blake_scheduler_info() - .mask_offsets - .as_cols_ref() - .map_cols(|_| self.log_size), - ); - for l in ROUND_LOG_SPLIT { - sizes.push( - blake_round_info() - .mask_offsets - .as_cols_ref() - .map_cols(|_| self.log_size + l), - ); - } - sizes.push(xor_table::trace_sizes::<12, 4>()); - sizes.push(xor_table::trace_sizes::<9, 2>()); - sizes.push(xor_table::trace_sizes::<8, 2>()); - sizes.push(xor_table::trace_sizes::<7, 2>()); - sizes.push(xor_table::trace_sizes::<4, 0>()); - - TreeVec::concat_cols(sizes.into_iter()) - } - fn mix_into(&self, channel: &mut impl Channel) { - // TODO(spapini): Do this better. - channel.mix_nonce(self.log_size as u64); - } -} - -pub struct AllElements { - blake_elements: BlakeElements, - round_elements: RoundElements, - xor_elements: BlakeXorElements, -} -impl AllElements { - pub fn draw(channel: &mut impl Channel) -> Self { - Self { - blake_elements: BlakeElements::draw(channel), - round_elements: RoundElements::draw(channel), - xor_elements: BlakeXorElements::draw(channel), - } - } -} - -pub struct BlakeStatement1 { - scheduler_claimed_sum: SecureField, - round_claimed_sums: Vec, - xor12_claimed_sum: SecureField, - xor9_claimed_sum: SecureField, - xor8_claimed_sum: SecureField, - xor7_claimed_sum: SecureField, - xor4_claimed_sum: SecureField, -} -impl BlakeStatement1 { - fn mix_into(&self, channel: &mut impl Channel) { - channel.mix_felts( - &chain![ - [ - self.scheduler_claimed_sum, - self.xor12_claimed_sum, - self.xor9_claimed_sum, - self.xor8_claimed_sum, - self.xor7_claimed_sum, - self.xor4_claimed_sum - ], - self.round_claimed_sums.clone() - ] - .collect_vec(), - ) - } -} - -pub struct BlakeProof { - stmt0: BlakeStatement0, - stmt1: BlakeStatement1, - stark_proof: StarkProof, -} - -pub struct BlakeComponents { - scheduler_component: BlakeSchedulerComponent, - round_components: Vec, - xor12: XorTableComponent<12, 4>, - xor9: XorTableComponent<9, 2>, - xor8: XorTableComponent<8, 2>, - xor7: XorTableComponent<7, 2>, - xor4: XorTableComponent<4, 0>, -} -impl BlakeComponents { - fn new(stmt0: &BlakeStatement0, all_elements: &AllElements, stmt1: &BlakeStatement1) -> Self { - let tree_span_provider = &mut TraceLocationAllocator::default(); - Self { - scheduler_component: BlakeSchedulerComponent::new( - tree_span_provider, - BlakeSchedulerEval { - log_size: stmt0.log_size, - blake_lookup_elements: all_elements.blake_elements.clone(), - round_lookup_elements: all_elements.round_elements.clone(), - claimed_sum: stmt1.scheduler_claimed_sum, - }, - ), - round_components: ROUND_LOG_SPLIT - .iter() - .zip(stmt1.round_claimed_sums.clone()) - .map(|(l, claimed_sum)| { - BlakeRoundComponent::new( - tree_span_provider, - BlakeRoundEval { - log_size: stmt0.log_size + l, - xor_lookup_elements: all_elements.xor_elements.clone(), - round_lookup_elements: all_elements.round_elements.clone(), - claimed_sum, - }, - ) - }) - .collect(), - xor12: XorTableComponent::new( - tree_span_provider, - XorTableEval { - lookup_elements: all_elements.xor_elements.xor12.clone(), - claimed_sum: stmt1.xor12_claimed_sum, - }, - ), - xor9: XorTableComponent::new( - tree_span_provider, - XorTableEval { - lookup_elements: all_elements.xor_elements.xor9.clone(), - claimed_sum: stmt1.xor9_claimed_sum, - }, - ), - xor8: XorTableComponent::new( - tree_span_provider, - XorTableEval { - lookup_elements: all_elements.xor_elements.xor8.clone(), - claimed_sum: stmt1.xor8_claimed_sum, - }, - ), - xor7: XorTableComponent::new( - tree_span_provider, - XorTableEval { - lookup_elements: all_elements.xor_elements.xor7.clone(), - claimed_sum: stmt1.xor7_claimed_sum, - }, - ), - xor4: XorTableComponent::new( - tree_span_provider, - XorTableEval { - lookup_elements: all_elements.xor_elements.xor4.clone(), - claimed_sum: stmt1.xor4_claimed_sum, - }, - ), - } - } - fn components(&self) -> Vec<&dyn Component> { - chain![ - [&self.scheduler_component as &dyn Component], - self.round_components.iter().map(|c| c as &dyn Component), - [ - &self.xor12 as &dyn Component, - &self.xor9 as &dyn Component, - &self.xor8 as &dyn Component, - &self.xor7 as &dyn Component, - &self.xor4 as &dyn Component, - ] - ] - .collect() - } - - fn component_provers(&self) -> Vec<&dyn ComponentProver> { - chain![ - [&self.scheduler_component as &dyn ComponentProver], - self.round_components - .iter() - .map(|c| c as &dyn ComponentProver), - [ - &self.xor12 as &dyn ComponentProver, - &self.xor9 as &dyn ComponentProver, - &self.xor8 as &dyn ComponentProver, - &self.xor7 as &dyn ComponentProver, - &self.xor4 as &dyn ComponentProver, - ] - ] - .collect() - } -} - -#[allow(unused)] -pub fn prove_blake(log_size: u32, config: PcsConfig) -> (BlakeProof) -where - SimdBackend: BackendForChannel, -{ - assert!(log_size >= LOG_N_LANES); - assert_eq!( - ROUND_LOG_SPLIT.map(|x| (1 << x)).into_iter().sum::() as usize, - N_ROUNDS - ); - - // Precompute twiddles. - let span = span!(Level::INFO, "Precompute twiddles").entered(); - const XOR_TABLE_MAX_LOG_SIZE: u32 = 16; - let log_max_rows = - (log_size + *ROUND_LOG_SPLIT.iter().max().unwrap()).max(XOR_TABLE_MAX_LOG_SIZE); - let twiddles = SimdBackend::precompute_twiddles( - CanonicCoset::new(log_max_rows + 1 + config.fri_config.log_blowup_factor) - .circle_domain() - .half_coset, - ); - span.exit(); - - // Prepare inputs. - let blake_inputs = (0..(1 << (log_size - LOG_N_LANES))) - .map(|i| { - let v = [u32x16::from_array(std::array::from_fn(|j| (i + 2 * j) as u32)); 16]; - let m = [u32x16::from_array(std::array::from_fn(|j| (i + 2 * j + 1) as u32)); 16]; - BlakeInput { v, m } - }) - .collect_vec(); - - // Setup protocol. - let channel = &mut MC::C::default(); - let commitment_scheme = &mut CommitmentSchemeProver::new(config, &twiddles); - - let span = span!(Level::INFO, "Trace").entered(); - - // Scheduler. - let (scheduler_trace, scheduler_lookup_data, round_inputs) = - scheduler::gen_trace(log_size, &blake_inputs); - - // Rounds. - let mut xor_accums = XorAccums::default(); - let mut rest = &round_inputs[..]; - // Split round inputs to components, according to [ROUND_LOG_SPLIT]. - let (round_traces, round_lookup_datas): (Vec<_>, Vec<_>) = - multiunzip(ROUND_LOG_SPLIT.map(|l| { - let (cur_inputs, r) = rest.split_at(1 << (log_size - LOG_N_LANES + l)); - rest = r; - round::generate_trace(log_size + l, cur_inputs, &mut xor_accums) - })); - - // Xor tables. - let (xor_trace12, xor_lookup_data12) = xor_table::generate_trace(xor_accums.xor12); - let (xor_trace9, xor_lookup_data9) = xor_table::generate_trace(xor_accums.xor9); - let (xor_trace8, xor_lookup_data8) = xor_table::generate_trace(xor_accums.xor8); - let (xor_trace7, xor_lookup_data7) = xor_table::generate_trace(xor_accums.xor7); - let (xor_trace4, xor_lookup_data4) = xor_table::generate_trace(xor_accums.xor4); - - // Statement0. - let stmt0 = BlakeStatement0 { log_size }; - stmt0.mix_into(channel); - - // Trace commitment. - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals( - chain![ - scheduler_trace, - round_traces.into_iter().flatten(), - xor_trace12, - xor_trace9, - xor_trace8, - xor_trace7, - xor_trace4, - ] - .collect_vec(), - ); - tree_builder.commit(channel); - span.exit(); - - // Draw lookup element. - let all_elements = AllElements::draw(channel); - - // Interaction trace. - let span = span!(Level::INFO, "Interaction").entered(); - let (scheduler_trace, scheduler_claimed_sum) = scheduler::gen_interaction_trace( - log_size, - scheduler_lookup_data, - &all_elements.round_elements, - &all_elements.blake_elements, - ); - - let (round_traces, round_claimed_sums): (Vec<_>, Vec<_>) = multiunzip( - ROUND_LOG_SPLIT - .iter() - .zip(round_lookup_datas) - .map(|(l, lookup_data)| { - round::generate_interaction_trace( - log_size + l, - lookup_data, - &all_elements.xor_elements, - &all_elements.round_elements, - ) - }), - ); - - let (xor_trace12, xor12_claimed_sum) = - xor_table::generate_interaction_trace(xor_lookup_data12, &all_elements.xor_elements.xor12); - let (xor_trace9, xor9_claimed_sum) = - xor_table::generate_interaction_trace(xor_lookup_data9, &all_elements.xor_elements.xor9); - let (xor_trace8, xor8_claimed_sum) = - xor_table::generate_interaction_trace(xor_lookup_data8, &all_elements.xor_elements.xor8); - let (xor_trace7, xor7_claimed_sum) = - xor_table::generate_interaction_trace(xor_lookup_data7, &all_elements.xor_elements.xor7); - let (xor_trace4, xor4_claimed_sum) = - xor_table::generate_interaction_trace(xor_lookup_data4, &all_elements.xor_elements.xor4); - - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals( - chain![ - scheduler_trace, - round_traces.into_iter().flatten(), - xor_trace12, - xor_trace9, - xor_trace8, - xor_trace7, - xor_trace4, - ] - .collect_vec(), - ); - - // Statement1. - let stmt1 = BlakeStatement1 { - scheduler_claimed_sum, - round_claimed_sums, - xor12_claimed_sum, - xor9_claimed_sum, - xor8_claimed_sum, - xor7_claimed_sum, - xor4_claimed_sum, - }; - stmt1.mix_into(channel); - tree_builder.commit(channel); - span.exit(); - - // Constant trace. - let span = span!(Level::INFO, "Constant Trace").entered(); - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals( - chain![ - xor_table::generate_constant_trace::<12, 4>(), - xor_table::generate_constant_trace::<9, 2>(), - xor_table::generate_constant_trace::<8, 2>(), - xor_table::generate_constant_trace::<7, 2>(), - xor_table::generate_constant_trace::<4, 0>(), - ] - .collect_vec(), - ); - tree_builder.commit(channel); - span.exit(); - - assert_eq!( - commitment_scheme - .polynomials() - .as_cols_ref() - .map_cols(|c| c.log_size()) - .0, - stmt0.log_sizes().0 - ); - - // Prove constraints. - let components = BlakeComponents::new(&stmt0, &all_elements, &stmt1); - let stark_proof = - prove::(&components.component_provers(), channel, commitment_scheme) - .unwrap(); - - BlakeProof { - stmt0, - stmt1, - stark_proof, - } -} - -#[allow(unused)] -pub fn verify_blake( - BlakeProof { - stmt0, - stmt1, - stark_proof, - }: BlakeProof, - config: PcsConfig, -) -> Result<(), VerificationError> { - let channel = &mut MC::C::default(); - let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); - - let log_sizes = stmt0.log_sizes(); - - // Trace. - stmt0.mix_into(channel); - commitment_scheme.commit(stark_proof.commitments[0], &log_sizes[0], channel); - - // Draw interaction elements. - let all_elements = AllElements::draw(channel); - - // Interaction trace. - stmt1.mix_into(channel); - commitment_scheme.commit(stark_proof.commitments[1], &log_sizes[1], channel); - - // Constant trace. - commitment_scheme.commit(stark_proof.commitments[2], &log_sizes[2], channel); - - let components = BlakeComponents::new(&stmt0, &all_elements, &stmt1); - - // Check that all sums are correct. - let total_sum = stmt1.scheduler_claimed_sum - + stmt1.round_claimed_sums.iter().sum::() - + stmt1.xor12_claimed_sum - + stmt1.xor9_claimed_sum - + stmt1.xor8_claimed_sum - + stmt1.xor7_claimed_sum - + stmt1.xor4_claimed_sum; - - // TODO(spapini): Add inputs to sum, and constraint them. - assert_eq!(total_sum, SecureField::zero()); - - verify( - &components.components(), - channel, - commitment_scheme, - stark_proof, - ) -} - -#[cfg(test)] -mod tests { - use std::env; - - use crate::core::pcs::PcsConfig; - use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; - use crate::examples::blake::air::{prove_blake, verify_blake}; - - // Note: this test is slow. Only run in release. - #[cfg_attr(not(feature = "slow-tests"), ignore)] - #[test_log::test] - fn test_simd_blake_prove() { - // Note: To see time measurement, run test with - // LOG_N_INSTANCES=16 RUST_LOG_SPAN_EVENTS=enter,close RUST_LOG=info RUSTFLAGS=" - // -C target-cpu=native -C target-feature=+avx512f" cargo test --release - // test_simd_blake_prove -- --nocapture --ignored - - // Get from environment variable: - let log_n_instances = env::var("LOG_N_INSTANCES") - .unwrap_or_else(|_| "6".to_string()) - .parse::() - .unwrap(); - let config = PcsConfig::default(); - - // Prove. - let proof = prove_blake::(log_n_instances, config); - - // Verify. - verify_blake::(proof, config).unwrap(); - } -} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/mod.rs b/Stwo_wrapper/crates/prover/src/examples/blake/mod.rs deleted file mode 100644 index 6fbe6d8..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/blake/mod.rs +++ /dev/null @@ -1,126 +0,0 @@ -//! AIR for blake2s and blake3. -//! See - -use std::fmt::Debug; -use std::ops::{Add, AddAssign, Mul, Sub}; -use std::simd::u32x16; - -use xor_table::{XorAccumulator, XorElements}; - -use crate::core::backend::simd::m31::PackedBaseField; -use crate::core::channel::Channel; -use crate::core::fields::m31::BaseField; -use crate::core::fields::FieldExpOps; - -mod air; -mod round; -mod scheduler; -mod xor_table; - -const STATE_SIZE: usize = 16; -const MESSAGE_SIZE: usize = 16; -const N_FELTS_IN_U32: usize = 2; -const N_ROUND_INPUT_FELTS: usize = (STATE_SIZE + STATE_SIZE + MESSAGE_SIZE) * N_FELTS_IN_U32; - -// Parameters for Blake2s. Change these for blake3. -const N_ROUNDS: usize = 10; -/// A splitting N_ROUNDS into several powers of 2. -const ROUND_LOG_SPLIT: [u32; 2] = [3, 1]; - -#[derive(Default)] -struct XorAccums { - xor12: XorAccumulator<12, 4>, - xor9: XorAccumulator<9, 2>, - xor8: XorAccumulator<8, 2>, - xor7: XorAccumulator<7, 2>, - xor4: XorAccumulator<4, 0>, -} -impl XorAccums { - fn add_input(&mut self, w: u32, a: u32x16, b: u32x16) { - match w { - 12 => self.xor12.add_input(a, b), - 9 => self.xor9.add_input(a, b), - 8 => self.xor8.add_input(a, b), - 7 => self.xor7.add_input(a, b), - 4 => self.xor4.add_input(a, b), - _ => panic!("Invalid w"), - } - } -} - -#[derive(Clone)] -pub struct BlakeXorElements { - xor12: XorElements, - xor9: XorElements, - xor8: XorElements, - xor7: XorElements, - xor4: XorElements, -} -impl BlakeXorElements { - fn draw(channel: &mut impl Channel) -> Self { - Self { - xor12: XorElements::draw(channel), - xor9: XorElements::draw(channel), - xor8: XorElements::draw(channel), - xor7: XorElements::draw(channel), - xor4: XorElements::draw(channel), - } - } - fn dummy() -> Self { - Self { - xor12: XorElements::dummy(), - xor9: XorElements::dummy(), - xor8: XorElements::dummy(), - xor7: XorElements::dummy(), - xor4: XorElements::dummy(), - } - } - fn get(&self, w: u32) -> &XorElements { - match w { - 12 => &self.xor12, - 9 => &self.xor9, - 8 => &self.xor8, - 7 => &self.xor7, - 4 => &self.xor4, - _ => panic!("Invalid w"), - } - } -} - -/// Utility for representing a u32 as two field elements, for constraint evaluation. -#[derive(Clone, Copy, Debug)] -struct Fu32 -where - F: FieldExpOps - + Copy - + Debug - + AddAssign - + Add - + Sub - + Mul, -{ - l: F, - h: F, -} -impl Fu32 -where - F: FieldExpOps - + Copy - + Debug - + AddAssign - + Add - + Sub - + Mul, -{ - fn to_felts(self) -> [F; 2] { - [self.l, self.h] - } -} - -/// Utility for splitting a u32 into 2 field elements in trace generation. -fn to_felts(x: &u32x16) -> [PackedBaseField; 2] { - [ - unsafe { PackedBaseField::from_simd_unchecked(x & u32x16::splat(0xffff)) }, - unsafe { PackedBaseField::from_simd_unchecked(x >> 16) }, - ] -} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/round/constraints.rs b/Stwo_wrapper/crates/prover/src/examples/blake/round/constraints.rs deleted file mode 100644 index 9440944..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/blake/round/constraints.rs +++ /dev/null @@ -1,164 +0,0 @@ -use itertools::{chain, Itertools}; -use num_traits::One; - -use super::{BlakeXorElements, RoundElements}; -use crate::constraint_framework::logup::LogupAtRow; -use crate::constraint_framework::EvalAtRow; -use crate::core::fields::m31::BaseField; -use crate::core::lookups::utils::Fraction; -use crate::examples::blake::{Fu32, STATE_SIZE}; - -const INV16: BaseField = BaseField::from_u32_unchecked(1 << 15); -const TWO: BaseField = BaseField::from_u32_unchecked(2); - -pub struct BlakeRoundEval<'a, E: EvalAtRow> { - pub eval: E, - pub xor_lookup_elements: &'a BlakeXorElements, - pub round_lookup_elements: &'a RoundElements, - pub logup: LogupAtRow<2, E>, -} -impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { - pub fn eval(mut self) -> E { - let mut v: [Fu32; STATE_SIZE] = std::array::from_fn(|_| self.next_u32()); - let input_v = v; - let m: [Fu32; STATE_SIZE] = std::array::from_fn(|_| self.next_u32()); - - self.g(v.get_many_mut([0, 4, 8, 12]).unwrap(), m[0], m[1]); - self.g(v.get_many_mut([1, 5, 9, 13]).unwrap(), m[2], m[3]); - self.g(v.get_many_mut([2, 6, 10, 14]).unwrap(), m[4], m[5]); - self.g(v.get_many_mut([3, 7, 11, 15]).unwrap(), m[6], m[7]); - self.g(v.get_many_mut([0, 5, 10, 15]).unwrap(), m[8], m[9]); - self.g(v.get_many_mut([1, 6, 11, 12]).unwrap(), m[10], m[11]); - self.g(v.get_many_mut([2, 7, 8, 13]).unwrap(), m[12], m[13]); - self.g(v.get_many_mut([3, 4, 9, 14]).unwrap(), m[14], m[15]); - - // Yield `Round(input_v, output_v, message)`. - self.logup.push_lookup( - &mut self.eval, - -E::EF::one(), - &chain![ - input_v.iter().copied().flat_map(Fu32::to_felts), - v.iter().copied().flat_map(Fu32::to_felts), - m.iter().copied().flat_map(Fu32::to_felts) - ] - .collect_vec(), - self.round_lookup_elements, - ); - - self.logup.finalize(&mut self.eval); - self.eval - } - fn next_u32(&mut self) -> Fu32 { - let l = self.eval.next_trace_mask(); - let h = self.eval.next_trace_mask(); - Fu32 { l, h } - } - fn g(&mut self, v: [&mut Fu32; 4], m0: Fu32, m1: Fu32) { - let [a, b, c, d] = v; - - *a = self.add3_u32_unchecked(*a, *b, m0); - *d = self.xor_rotr16_u32(*a, *d); - *c = self.add2_u32_unchecked(*c, *d); - *b = self.xor_rotr_u32(*b, *c, 12); - *a = self.add3_u32_unchecked(*a, *b, m1); - *d = self.xor_rotr_u32(*a, *d, 8); - *c = self.add2_u32_unchecked(*c, *d); - *b = self.xor_rotr_u32(*b, *c, 7); - } - - /// Adds two u32s, returning the sum. - /// Assumes a, b are properly range checked. - /// The caller is responsible for checking: - /// res.{l,h} not in [2^16, 2^17) or in [-2^16,0) - fn add2_u32_unchecked(&mut self, a: Fu32, b: Fu32) -> Fu32 { - let sl = self.eval.next_trace_mask(); - let sh = self.eval.next_trace_mask(); - - let carry_l = (a.l + b.l - sl) * E::F::from(INV16); - self.eval.add_constraint(carry_l * carry_l - carry_l); - - let carry_h = (a.h + b.h + carry_l - sh) * E::F::from(INV16); - self.eval.add_constraint(carry_h * carry_h - carry_h); - - Fu32 { l: sl, h: sh } - } - - /// Adds three u32s, returning the sum. - /// Assumes a, b, c are properly range checked. - /// Caller is responsible for checking: - /// res.{l,h} not in [2^16, 3*2^16) or in [-2^17,0) - fn add3_u32_unchecked(&mut self, a: Fu32, b: Fu32, c: Fu32) -> Fu32 { - let sl = self.eval.next_trace_mask(); - let sh = self.eval.next_trace_mask(); - - let carry_l = (a.l + b.l + c.l - sl) * E::F::from(INV16); - self.eval - .add_constraint(carry_l * (carry_l - E::F::one()) * (carry_l - E::F::from(TWO))); - - let carry_h = (a.h + b.h + c.h + carry_l - sh) * E::F::from(INV16); - self.eval - .add_constraint(carry_h * (carry_h - E::F::one()) * (carry_h - E::F::from(TWO))); - - Fu32 { l: sl, h: sh } - } - - /// Splits a felt at r. - /// Caller is responsible for checking that the ranges of h * 2^r and l don't overlap. - fn split_unchecked(&mut self, a: E::F, r: u32) -> (E::F, E::F) { - let h = self.eval.next_trace_mask(); - let l = a - h * E::F::from(BaseField::from_u32_unchecked(1 << r)); - (l, h) - } - - /// Checks that a, b are in range, and computes their xor rotated right by `r` bits. - /// Guarantees that all elements are in range. - fn xor_rotr_u32(&mut self, a: Fu32, b: Fu32, r: u32) -> Fu32 { - let (all, alh) = self.split_unchecked(a.l, r); - let (ahl, ahh) = self.split_unchecked(a.h, r); - let (bll, blh) = self.split_unchecked(b.l, r); - let (bhl, bhh) = self.split_unchecked(b.h, r); - - // These also guarantee that all elements are in range. - let [xorll, xorhl] = self.xor2(r, [all, ahl], [bll, bhl]); - let [xorlh, xorhh] = self.xor2(16 - r, [alh, ahh], [blh, bhh]); - - Fu32 { - l: xorhl * E::F::from(BaseField::from_u32_unchecked(1 << (16 - r))) + xorlh, - h: xorll * E::F::from(BaseField::from_u32_unchecked(1 << (16 - r))) + xorhh, - } - } - - /// Checks that a, b are in range, and computes their xor rotated right by 16 bits. - /// Guarantees that all elements are in range. - fn xor_rotr16_u32(&mut self, a: Fu32, b: Fu32) -> Fu32 { - let (all, alh) = self.split_unchecked(a.l, 8); - let (ahl, ahh) = self.split_unchecked(a.h, 8); - let (bll, blh) = self.split_unchecked(b.l, 8); - let (bhl, bhh) = self.split_unchecked(b.h, 8); - - // These also guarantee that all elements are in range. - let [xorll, xorhl] = self.xor2(8, [all, ahl], [bll, bhl]); - let [xorlh, xorhh] = self.xor2(8, [alh, ahh], [blh, bhh]); - - Fu32 { - l: xorhh * E::F::from(BaseField::from_u32_unchecked(1 << 8)) + xorhl, - h: xorlh * E::F::from(BaseField::from_u32_unchecked(1 << 8)) + xorll, - } - } - - /// Checks that a, b are in [0, 2^w) and computes their xor. - fn xor2(&mut self, w: u32, a: [E::F; 2], b: [E::F; 2]) -> [E::F; 2] { - // TODO: Separate lookups by w. - let c = [self.eval.next_trace_mask(), self.eval.next_trace_mask()]; - let lookup_elements = self.xor_lookup_elements.get(w); - let comb0 = lookup_elements.combine::(&[a[0], b[0], c[0]]); - let comb1 = lookup_elements.combine::(&[a[1], b[1], c[1]]); - let frac = Fraction { - numerator: comb0 + comb1, - denominator: comb0 * comb1, - }; - - self.logup.add_frac(&mut self.eval, frac); - c - } -} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/round/gen.rs b/Stwo_wrapper/crates/prover/src/examples/blake/round/gen.rs deleted file mode 100644 index ba9933b..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/blake/round/gen.rs +++ /dev/null @@ -1,281 +0,0 @@ -use std::simd::u32x16; -use std::vec; - -use itertools::{chain, Itertools}; -use num_traits::One; -use tracing::{span, Level}; - -use super::{BlakeXorElements, RoundElements}; -use crate::constraint_framework::logup::LogupTraceGenerator; -use crate::core::backend::simd::column::BaseColumn; -use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; -use crate::core::backend::simd::qm31::PackedSecureField; -use crate::core::backend::simd::SimdBackend; -use crate::core::backend::{Col, Column}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; -use crate::core::poly::BitReversedOrder; -use crate::core::ColumnVec; -use crate::examples::blake::round::blake_round_info; -use crate::examples::blake::{to_felts, XorAccums, N_ROUND_INPUT_FELTS, STATE_SIZE}; - -pub struct BlakeRoundLookupData { - /// A vector of (w, [a_col, b_col, c_col]) for each xor lookup. - /// w is the xor width. c_col is the xor col of a_col and b_col. - xor_lookups: Vec<(u32, [BaseColumn; 3])>, - /// A column of round lookup values (v_in, v_out, m). - round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], -} - -pub struct TraceGenerator { - log_size: u32, - trace: Vec, - xor_lookups: Vec<(u32, [BaseColumn; 3])>, - round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], -} -impl TraceGenerator { - fn new(log_size: u32) -> Self { - assert!(log_size >= LOG_N_LANES); - let trace = (0..blake_round_info().mask_offsets[0].len()) - .map(|_| unsafe { Col::::uninitialized(1 << log_size) }) - .collect_vec(); - Self { - log_size, - trace, - xor_lookups: vec![], - round_lookup: std::array::from_fn(|_| unsafe { - BaseColumn::uninitialized(1 << log_size) - }), - } - } - - fn gen_row(&mut self, vec_row: usize) -> TraceGeneratorRow<'_> { - TraceGeneratorRow { - gen: self, - col_index: 0, - vec_row, - xor_lookups_index: 0, - } - } -} - -/// Trace generator for the constraints defined at [`super::constraints::BlakeRoundEval`] -struct TraceGeneratorRow<'a> { - gen: &'a mut TraceGenerator, - col_index: usize, - vec_row: usize, - xor_lookups_index: usize, -} -impl<'a> TraceGeneratorRow<'a> { - fn append_felt(&mut self, val: u32x16) { - self.gen.trace[self.col_index].data[self.vec_row] = - unsafe { PackedBaseField::from_simd_unchecked(val) }; - self.col_index += 1; - } - - fn append_u32(&mut self, val: u32x16) { - self.append_felt(val & u32x16::splat(0xffff)); - self.append_felt(val >> 16); - } - - fn generate(&mut self, mut v: [u32x16; 16], m: [u32x16; 16]) { - let input_v = v; - v.iter().for_each(|s| { - self.append_u32(*s); - }); - m.iter().for_each(|s| { - self.append_u32(*s); - }); - - self.g(v.get_many_mut([0, 4, 8, 12]).unwrap(), m[0], m[1]); - self.g(v.get_many_mut([1, 5, 9, 13]).unwrap(), m[2], m[3]); - self.g(v.get_many_mut([2, 6, 10, 14]).unwrap(), m[4], m[5]); - self.g(v.get_many_mut([3, 7, 11, 15]).unwrap(), m[6], m[7]); - self.g(v.get_many_mut([0, 5, 10, 15]).unwrap(), m[8], m[9]); - self.g(v.get_many_mut([1, 6, 11, 12]).unwrap(), m[10], m[11]); - self.g(v.get_many_mut([2, 7, 8, 13]).unwrap(), m[12], m[13]); - self.g(v.get_many_mut([3, 4, 9, 14]).unwrap(), m[14], m[15]); - - chain![input_v.iter(), v.iter(), m.iter()] - .flat_map(to_felts) - .enumerate() - .for_each(|(i, felt)| self.gen.round_lookup[i].data[self.vec_row] = felt); - } - - fn g(&mut self, v: [&mut u32x16; 4], m0: u32x16, m1: u32x16) { - let [a, b, c, d] = v; - - *a = self.add3_u32s(*a, *b, m0); - *d = self.xor_rotr16_u32(*a, *d); - *c = self.add2_u32s(*c, *d); - *b = self.xor_rotr_u32(*b, *c, 12); - *a = self.add3_u32s(*a, *b, m1); - *d = self.xor_rotr_u32(*a, *d, 8); - *c = self.add2_u32s(*c, *d); - *b = self.xor_rotr_u32(*b, *c, 7); - } - - /// Adds two u32s, returning the sum. - fn add2_u32s(&mut self, a: u32x16, b: u32x16) -> u32x16 { - let s = a + b; - self.append_u32(s); - s - } - - /// Adds three u32s, returning the sum. - fn add3_u32s(&mut self, a: u32x16, b: u32x16, c: u32x16) -> u32x16 { - let s = a + b + c; - self.append_u32(s); - s - } - - /// Splits a felt at r. - fn split(&mut self, a: u32x16, r: u32) -> (u32x16, u32x16) { - let h = a >> r; - let l = a & u32x16::splat((1 << r) - 1); - self.append_felt(h); - (l, h) - } - - /// Checks that a, b are in range, and computes their xor rotated right by `r` bits. - fn xor_rotr_u32(&mut self, a: u32x16, b: u32x16, r: u32) -> u32x16 { - let c = a ^ b; - let cr = (c >> r) | (c << (32 - r)); - - let (all, alh) = self.split(a & u32x16::splat(0xffff), r); - let (ahl, ahh) = self.split(a >> 16, r); - let (bll, blh) = self.split(b & u32x16::splat(0xffff), r); - let (bhl, bhh) = self.split(b >> 16, r); - - self.xor(r, all, bll); - self.xor(r, ahl, bhl); - self.xor(16 - r, alh, blh); - self.xor(16 - r, ahh, bhh); - - cr - } - - /// Checks that a, b are in range, and computes their xor rotated right by 16 bits. - fn xor_rotr16_u32(&mut self, a: u32x16, b: u32x16) -> u32x16 { - let c = a ^ b; - let cr = (c >> 16) | (c << 16); - - let (all, alh) = self.split(a & u32x16::splat(0xffff), 8); - let (ahl, ahh) = self.split(a >> 16, 8); - let (bll, blh) = self.split(b & u32x16::splat(0xffff), 8); - let (bhl, bhh) = self.split(b >> 16, 8); - - self.xor(8, all, bll); - self.xor(8, ahl, bhl); - self.xor(8, alh, blh); - self.xor(8, ahh, bhh); - - cr - } - - /// Checks that a, b are in [0, 2^w) and computes their xor. - /// a,b,a^b are assumed to fit in a single felt. - fn xor(&mut self, w: u32, a: u32x16, b: u32x16) -> u32x16 { - let c = a ^ b; - self.append_felt(c); - if self.gen.xor_lookups.len() <= self.xor_lookups_index { - self.gen.xor_lookups.push(( - w, - std::array::from_fn(|_| unsafe { - BaseColumn::uninitialized(1 << self.gen.log_size) - }), - )); - } - self.gen.xor_lookups[self.xor_lookups_index].1[0].data[self.vec_row] = - unsafe { PackedBaseField::from_simd_unchecked(a) }; - self.gen.xor_lookups[self.xor_lookups_index].1[1].data[self.vec_row] = - unsafe { PackedBaseField::from_simd_unchecked(b) }; - self.gen.xor_lookups[self.xor_lookups_index].1[2].data[self.vec_row] = - unsafe { PackedBaseField::from_simd_unchecked(c) }; - self.xor_lookups_index += 1; - c - } -} - -#[derive(Copy, Clone, Default)] -pub struct BlakeRoundInput { - pub v: [u32x16; STATE_SIZE], - pub m: [u32x16; STATE_SIZE], -} - -pub fn generate_trace( - log_size: u32, - inputs: &[BlakeRoundInput], - xor_accum: &mut XorAccums, -) -> ( - ColumnVec>, - BlakeRoundLookupData, -) { - let _span = span!(Level::INFO, "Round Generation").entered(); - let mut generator = TraceGenerator::new(log_size); - - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { - let mut row_gen = generator.gen_row(vec_row); - let BlakeRoundInput { v, m } = inputs.get(vec_row).copied().unwrap_or_default(); - row_gen.generate(v, m); - for (w, [a, b, _c]) in &generator.xor_lookups { - let a = a.data[vec_row].into_simd(); - let b = b.data[vec_row].into_simd(); - xor_accum.add_input(*w, a, b); - } - } - let domain = CanonicCoset::new(log_size).circle_domain(); - ( - generator - .trace - .into_iter() - .map(|eval| CircleEvaluation::::new(domain, eval)) - .collect_vec(), - BlakeRoundLookupData { - xor_lookups: generator.xor_lookups, - round_lookup: generator.round_lookup, - }, - ) -} - -pub fn generate_interaction_trace( - log_size: u32, - lookup_data: BlakeRoundLookupData, - xor_lookup_elements: &BlakeXorElements, - round_lookup_elements: &RoundElements, -) -> ( - ColumnVec>, - SecureField, -) { - let _span = span!(Level::INFO, "Generate round interaction trace").entered(); - let mut logup_gen = LogupTraceGenerator::new(log_size); - - for [(w0, l0), (w1, l1)] in lookup_data.xor_lookups.array_chunks::<2>() { - let mut col_gen = logup_gen.new_col(); - - #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { - let p0: PackedSecureField = xor_lookup_elements - .get(*w0) - .combine(&l0.each_ref().map(|l| l.data[vec_row])); - let p1: PackedSecureField = xor_lookup_elements - .get(*w1) - .combine(&l1.each_ref().map(|l| l.data[vec_row])); - col_gen.write_frac(vec_row, p0 + p1, p0 * p1); - } - - col_gen.finalize_col(); - } - - let mut col_gen = logup_gen.new_col(); - #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { - let p = round_lookup_elements - .combine(&lookup_data.round_lookup.each_ref().map(|l| l.data[vec_row])); - col_gen.write_frac(vec_row, -PackedSecureField::one(), p); - } - col_gen.finalize_col(); - - logup_gen.finalize() -} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/round/mod.rs b/Stwo_wrapper/crates/prover/src/examples/blake/round/mod.rs deleted file mode 100644 index cf83113..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/blake/round/mod.rs +++ /dev/null @@ -1,110 +0,0 @@ -mod constraints; -mod gen; - -pub use gen::{generate_interaction_trace, generate_trace, BlakeRoundInput}; -use num_traits::Zero; - -use super::{BlakeXorElements, N_ROUND_INPUT_FELTS}; -use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; -use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator}; -use crate::core::fields::qm31::SecureField; - -pub type BlakeRoundComponent = FrameworkComponent; - -pub type RoundElements = LookupElements; - -pub struct BlakeRoundEval { - pub log_size: u32, - pub xor_lookup_elements: BlakeXorElements, - pub round_lookup_elements: RoundElements, - pub claimed_sum: SecureField, -} - -impl FrameworkEval for BlakeRoundEval { - fn log_size(&self) -> u32 { - self.log_size - } - fn max_constraint_log_degree_bound(&self) -> u32 { - self.log_size + 1 - } - fn evaluate(&self, eval: E) -> E { - let blake_eval = constraints::BlakeRoundEval { - eval, - xor_lookup_elements: &self.xor_lookup_elements, - round_lookup_elements: &self.round_lookup_elements, - logup: LogupAtRow::new(1, self.claimed_sum, self.log_size), - }; - blake_eval.eval() - } -} - -pub fn blake_round_info() -> InfoEvaluator { - let component = BlakeRoundEval { - log_size: 1, - xor_lookup_elements: BlakeXorElements::dummy(), - round_lookup_elements: RoundElements::dummy(), - claimed_sum: SecureField::zero(), - }; - component.evaluate(InfoEvaluator::default()) -} - -#[cfg(test)] -mod tests { - use std::simd::Simd; - - use itertools::Itertools; - - use crate::constraint_framework::constant_columns::gen_is_first; - use crate::constraint_framework::FrameworkEval; - use crate::core::poly::circle::CanonicCoset; - use crate::examples::blake::round::r#gen::{ - generate_interaction_trace, generate_trace, BlakeRoundInput, - }; - use crate::examples::blake::round::{BlakeRoundEval, RoundElements}; - use crate::examples::blake::{BlakeXorElements, XorAccums}; - - #[test] - fn test_blake_round() { - use crate::core::pcs::TreeVec; - - const LOG_SIZE: u32 = 10; - - let mut xor_accum = XorAccums::default(); - let (trace, lookup_data) = generate_trace( - LOG_SIZE, - &(0..(1 << LOG_SIZE)) - .map(|_| BlakeRoundInput { - v: std::array::from_fn(|i| Simd::splat(i as u32)), - m: std::array::from_fn(|i| Simd::splat((i + 1) as u32)), - }) - .collect_vec(), - &mut xor_accum, - ); - - let xor_lookup_elements = BlakeXorElements::dummy(); - let round_lookup_elements = RoundElements::dummy(); - let (interaction_trace, claimed_sum) = generate_interaction_trace( - LOG_SIZE, - lookup_data, - &xor_lookup_elements, - &round_lookup_elements, - ); - - let trace = TreeVec::new(vec![trace, interaction_trace, vec![gen_is_first(LOG_SIZE)]]); - let trace_polys = trace.map_cols(|c| c.interpolate()); - - let component = BlakeRoundEval { - log_size: LOG_SIZE, - xor_lookup_elements, - round_lookup_elements, - claimed_sum, - }; - crate::constraint_framework::assert_constraints( - &trace_polys, - CanonicCoset::new(LOG_SIZE), - |eval| { - component.evaluate(eval); - }, - ) - } -} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/constraints.rs b/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/constraints.rs deleted file mode 100644 index 63b3cf6..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/constraints.rs +++ /dev/null @@ -1,64 +0,0 @@ -use itertools::{chain, Itertools}; -use num_traits::{One, Zero}; - -use super::BlakeElements; -use crate::constraint_framework::logup::LogupAtRow; -use crate::constraint_framework::EvalAtRow; -use crate::core::vcs::blake2s_ref::SIGMA; -use crate::examples::blake::round::RoundElements; -use crate::examples::blake::{Fu32, N_ROUNDS, STATE_SIZE}; - -pub fn eval_blake_scheduler_constraints( - eval: &mut E, - blake_lookup_elements: &BlakeElements, - round_lookup_elements: &RoundElements, - mut logup: LogupAtRow<2, E>, -) { - let messages: [Fu32; STATE_SIZE] = std::array::from_fn(|_| eval_next_u32(eval)); - let states: [[Fu32; STATE_SIZE]; N_ROUNDS + 1] = - std::array::from_fn(|_| std::array::from_fn(|_| eval_next_u32(eval))); - - // Schedule. - for i in 0..N_ROUNDS { - let input_state = &states[i]; - let output_state = &states[i + 1]; - let round_messages = SIGMA[i].map(|j| messages[j as usize]); - // Use triplet in round lookup. - logup.push_lookup( - eval, - E::EF::one(), - &chain![ - input_state.iter().copied().flat_map(Fu32::to_felts), - output_state.iter().copied().flat_map(Fu32::to_felts), - round_messages.iter().copied().flat_map(Fu32::to_felts) - ] - .collect_vec(), - round_lookup_elements, - ) - } - - let input_state = &states[0]; - let output_state = &states[N_ROUNDS]; - - // TODO(spapini): Support multiplicities. - // TODO(spapini): Change to -1. - logup.push_lookup( - eval, - E::EF::zero(), - &chain![ - input_state.iter().copied().flat_map(Fu32::to_felts), - output_state.iter().copied().flat_map(Fu32::to_felts), - messages.iter().copied().flat_map(Fu32::to_felts) - ] - .collect_vec(), - blake_lookup_elements, - ); - - logup.finalize(eval); -} - -fn eval_next_u32(eval: &mut E) -> Fu32 { - let l = eval.next_trace_mask(); - let h = eval.next_trace_mask(); - Fu32 { l, h } -} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/gen.rs b/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/gen.rs deleted file mode 100644 index 0581b2f..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/gen.rs +++ /dev/null @@ -1,171 +0,0 @@ -use std::simd::u32x16; - -use itertools::{chain, Itertools}; -use num_traits::Zero; -use tracing::{span, Level}; - -use super::{blake_scheduler_info, BlakeElements}; -use crate::constraint_framework::logup::LogupTraceGenerator; -use crate::core::backend::simd::column::BaseColumn; -use crate::core::backend::simd::m31::LOG_N_LANES; -use crate::core::backend::simd::qm31::PackedSecureField; -use crate::core::backend::simd::{blake2s, SimdBackend}; -use crate::core::backend::Column; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; -use crate::core::poly::BitReversedOrder; -use crate::core::ColumnVec; -use crate::examples::blake::round::{BlakeRoundInput, RoundElements}; -use crate::examples::blake::{to_felts, N_ROUNDS, N_ROUND_INPUT_FELTS, STATE_SIZE}; - -#[derive(Copy, Clone, Default)] -pub struct BlakeInput { - pub v: [u32x16; STATE_SIZE], - pub m: [u32x16; STATE_SIZE], -} - -pub struct BlakeSchedulerLookupData { - pub round_lookups: [[BaseColumn; N_ROUND_INPUT_FELTS]; N_ROUNDS], - pub blake_lookups: [BaseColumn; N_ROUND_INPUT_FELTS], -} -impl BlakeSchedulerLookupData { - fn new(log_size: u32) -> Self { - Self { - round_lookups: std::array::from_fn(|_| { - std::array::from_fn(|_| unsafe { BaseColumn::uninitialized(1 << log_size) }) - }), - blake_lookups: std::array::from_fn(|_| unsafe { - BaseColumn::uninitialized(1 << log_size) - }), - } - } -} - -pub fn gen_trace( - log_size: u32, - inputs: &[BlakeInput], -) -> ( - ColumnVec>, - BlakeSchedulerLookupData, - Vec, -) { - let _span = span!(Level::INFO, "Scheduler Generation").entered(); - let mut lookup_data = BlakeSchedulerLookupData::new(log_size); - let mut round_inputs = Vec::with_capacity(inputs.len() * N_ROUNDS); - - let mut trace = (0..blake_scheduler_info().mask_offsets[0].len()) - .map(|_| unsafe { BaseColumn::uninitialized(1 << log_size) }) - .collect_vec(); - - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { - let mut col_index = 0; - - let mut write_u32_array = |x: [u32x16; STATE_SIZE], col_index: &mut usize| { - x.iter().for_each(|x| { - to_felts(x).iter().for_each(|x| { - trace[*col_index].data[vec_row] = *x; - *col_index += 1; - }); - }); - }; - - let BlakeInput { mut v, m } = inputs.get(vec_row).copied().unwrap_or_default(); - let initial_v = v; - write_u32_array(m, &mut col_index); - write_u32_array(v, &mut col_index); - - for r in 0..N_ROUNDS { - let prev_v = v; - blake2s::round(&mut v, m, r); - write_u32_array(v, &mut col_index); - - let round_m = blake2s::SIGMA[r].map(|i| m[i as usize]); - round_inputs.push(BlakeRoundInput { - v: prev_v, - m: round_m, - }); - - chain![ - prev_v.iter().flat_map(to_felts), - v.iter().flat_map(to_felts), - round_m.iter().flat_map(to_felts) - ] - .enumerate() - .for_each(|(i, val)| lookup_data.round_lookups[r][i].data[vec_row] = val); - } - - chain![ - initial_v.iter().flat_map(to_felts), - v.iter().flat_map(to_felts), - m.iter().flat_map(to_felts) - ] - .enumerate() - .for_each(|(i, val)| lookup_data.blake_lookups[i].data[vec_row] = val); - } - - let domain = CanonicCoset::new(log_size).circle_domain(); - let trace = trace - .into_iter() - .map(|eval| CircleEvaluation::::new(domain, eval)) - .collect_vec(); - - (trace, lookup_data, round_inputs) -} -pub fn gen_interaction_trace( - log_size: u32, - lookup_data: BlakeSchedulerLookupData, - round_lookup_elements: &RoundElements, - blake_lookup_elements: &BlakeElements, -) -> ( - ColumnVec>, - SecureField, -) { - let _span = span!(Level::INFO, "Generate scheduler interaction trace").entered(); - - let mut logup_gen = LogupTraceGenerator::new(log_size); - - for [l0, l1] in lookup_data.round_lookups.array_chunks::<2>() { - let mut col_gen = logup_gen.new_col(); - - #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { - let p0: PackedSecureField = - round_lookup_elements.combine(&l0.each_ref().map(|l| l.data[vec_row])); - let p1: PackedSecureField = - round_lookup_elements.combine(&l1.each_ref().map(|l| l.data[vec_row])); - #[allow(clippy::eq_op)] - col_gen.write_frac(vec_row, p0 + p1, p0 * p1); - } - - col_gen.finalize_col(); - } - - // Last pair. If the number of round is odd (as in blake3), we combine that last round lookup - // with the entire blake lookup. - let mut col_gen = logup_gen.new_col(); - #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { - let p_blake: PackedSecureField = blake_lookup_elements.combine( - &lookup_data - .blake_lookups - .each_ref() - .map(|l| l.data[vec_row]), - ); - if N_ROUNDS % 2 == 1 { - let p_round: PackedSecureField = round_lookup_elements.combine( - &lookup_data.round_lookups[N_ROUNDS - 1] - .each_ref() - .map(|l| l.data[vec_row]), - ); - // TODO(spapini): Change blake numerator to p_blake - p_round. - col_gen.write_frac(vec_row, p_blake, p_round * p_blake); - } else { - // TODO(spapini): Change numerator to -1. - col_gen.write_frac(vec_row, PackedSecureField::zero(), p_blake); - } - } - col_gen.finalize_col(); - - logup_gen.finalize() -} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/mod.rs b/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/mod.rs deleted file mode 100644 index e8a8c32..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/mod.rs +++ /dev/null @@ -1,106 +0,0 @@ -mod constraints; -mod gen; - -use constraints::eval_blake_scheduler_constraints; -pub use gen::{gen_interaction_trace, gen_trace, BlakeInput}; -use num_traits::Zero; - -use super::round::RoundElements; -use super::N_ROUND_INPUT_FELTS; -use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; -use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator}; -use crate::core::fields::qm31::SecureField; - -pub type BlakeSchedulerComponent = FrameworkComponent; - -pub type BlakeElements = LookupElements; - -pub struct BlakeSchedulerEval { - pub log_size: u32, - pub blake_lookup_elements: BlakeElements, - pub round_lookup_elements: RoundElements, - pub claimed_sum: SecureField, -} -impl FrameworkEval for BlakeSchedulerEval { - fn log_size(&self) -> u32 { - self.log_size - } - fn max_constraint_log_degree_bound(&self) -> u32 { - self.log_size + 1 - } - fn evaluate(&self, mut eval: E) -> E { - eval_blake_scheduler_constraints( - &mut eval, - &self.blake_lookup_elements, - &self.round_lookup_elements, - LogupAtRow::new(1, self.claimed_sum, self.log_size), - ); - eval - } -} - -pub fn blake_scheduler_info() -> InfoEvaluator { - let component = BlakeSchedulerEval { - log_size: 1, - blake_lookup_elements: BlakeElements::dummy(), - round_lookup_elements: RoundElements::dummy(), - claimed_sum: SecureField::zero(), - }; - component.evaluate(InfoEvaluator::default()) -} - -#[cfg(test)] -mod tests { - use std::simd::Simd; - - use itertools::Itertools; - - use crate::constraint_framework::FrameworkEval; - use crate::core::poly::circle::CanonicCoset; - use crate::examples::blake::round::RoundElements; - use crate::examples::blake::scheduler::r#gen::{gen_interaction_trace, gen_trace, BlakeInput}; - use crate::examples::blake::scheduler::{BlakeElements, BlakeSchedulerEval}; - - #[test] - fn test_blake_scheduler() { - use crate::core::pcs::TreeVec; - - const LOG_SIZE: u32 = 10; - - let (trace, lookup_data, _round_inputs) = gen_trace( - LOG_SIZE, - &(0..(1 << LOG_SIZE)) - .map(|_| BlakeInput { - v: std::array::from_fn(|i| Simd::splat(i as u32)), - m: std::array::from_fn(|i| Simd::splat((i + 1) as u32)), - }) - .collect_vec(), - ); - - let round_lookup_elements = RoundElements::dummy(); - let blake_lookup_elements = BlakeElements::dummy(); - let (interaction_trace, claimed_sum) = gen_interaction_trace( - LOG_SIZE, - lookup_data, - &round_lookup_elements, - &blake_lookup_elements, - ); - - let trace = TreeVec::new(vec![trace, interaction_trace]); - let trace_polys = trace.map_cols(|c| c.interpolate()); - - let component = BlakeSchedulerEval { - log_size: LOG_SIZE, - blake_lookup_elements, - round_lookup_elements, - claimed_sum, - }; - crate::constraint_framework::assert_constraints( - &trace_polys, - CanonicCoset::new(LOG_SIZE), - |eval| { - component.evaluate(eval); - }, - ) - } -} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/constraints.rs b/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/constraints.rs deleted file mode 100644 index 00a6583..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/constraints.rs +++ /dev/null @@ -1,52 +0,0 @@ -use super::{limb_bits, XorElements}; -use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; -use crate::constraint_framework::EvalAtRow; -use crate::core::fields::m31::BaseField; - -/// Constraints for the xor table. -pub struct XorTableEval<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> { - pub eval: E, - pub lookup_elements: &'a XorElements, - pub logup: LogupAtRow<2, E>, -} -impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> - XorTableEval<'a, E, ELEM_BITS, EXPAND_BITS> -{ - pub fn eval(mut self) -> E { - // al, bl are the constant columns for the inputs: All pairs of elements in [0, - // 2^LIMB_BITS). - // cl is the constant column for the xor: al ^ bl. - let [al] = self.eval.next_interaction_mask(2, [0]); - let [bl] = self.eval.next_interaction_mask(2, [0]); - let [cl] = self.eval.next_interaction_mask(2, [0]); - for i in 0..1 << EXPAND_BITS { - for j in 0..1 << EXPAND_BITS { - let multiplicity = self.eval.next_trace_mask(); - - let a = al - + E::F::from(BaseField::from_u32_unchecked( - i << limb_bits::(), - )); - let b = bl - + E::F::from(BaseField::from_u32_unchecked( - j << limb_bits::(), - )); - let c = cl - + E::F::from(BaseField::from_u32_unchecked( - (i ^ j) << limb_bits::(), - )); - - // Add with negative multiplicity. Consumers should lookup with positive - // multiplicity. - self.logup.push_lookup( - &mut self.eval, - (-multiplicity).into(), - &[a, b, c], - self.lookup_elements, - ); - } - } - self.logup.finalize(&mut self.eval); - self.eval - } -} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/gen.rs b/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/gen.rs deleted file mode 100644 index 195a6ca..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/gen.rs +++ /dev/null @@ -1,168 +0,0 @@ -use std::simd::u32x16; - -use itertools::Itertools; -use tracing::{span, Level}; - -use super::{column_bits, limb_bits, XorAccumulator, XorElements}; -use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements}; -use crate::core::backend::simd::column::BaseColumn; -use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; -use crate::core::backend::simd::qm31::PackedSecureField; -use crate::core::backend::simd::SimdBackend; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; -use crate::core::poly::BitReversedOrder; -use crate::core::ColumnVec; - -pub struct XorTableLookupData { - pub xor_accum: XorAccumulator, -} - -pub fn generate_trace( - xor_accum: XorAccumulator, -) -> ( - ColumnVec>, - XorTableLookupData, -) { - ( - xor_accum - .mults - .iter() - .map(|mult| { - CircleEvaluation::new( - CanonicCoset::new(column_bits::()).circle_domain(), - mult.clone(), - ) - }) - .collect_vec(), - XorTableLookupData { xor_accum }, - ) -} - -/// Generates the interaction trace for the xor table. -/// Returns the interaction trace, the constant trace, and the claimed sum. -#[allow(clippy::type_complexity)] -pub fn generate_interaction_trace( - lookup_data: XorTableLookupData, - lookup_elements: &XorElements, -) -> ( - ColumnVec>, - SecureField, -) { - let limb_bits = limb_bits::(); - let _span = span!(Level::INFO, "Xor interaction trace").entered(); - let offsets_vec = u32x16::from_array(std::array::from_fn(|i| i as u32)); - let mut logup_gen = LogupTraceGenerator::new(column_bits::()); - - // Iterate each pair of columns, to batch their lookup together. - // There are 2^(2*EXPAND_BITS) column, for each combination of ah, bh. - let mut iter = lookup_data - .xor_accum - .mults - .iter() - .enumerate() - .array_chunks::<2>(); - for [(i0, mults0), (i1, mults1)] in &mut iter { - let mut col_gen = logup_gen.new_col(); - - // Extract ah, bh from column index. - let ah0 = i0 as u32 >> EXPAND_BITS; - let bh0 = i0 as u32 & ((1 << EXPAND_BITS) - 1); - let ah1 = i1 as u32 >> EXPAND_BITS; - let bh1 = i1 as u32 & ((1 << EXPAND_BITS) - 1); - - // Each column has 2^(2*LIMB_BITS) rows, packed in N_LANES. - #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (column_bits::() - LOG_N_LANES)) { - // vec_row is LIMB_BITS of al and LIMB_BITS - LOG_N_LANES of bl. - // Extract al, blh from vec_row. - let al = vec_row >> (limb_bits - LOG_N_LANES); - let blh = vec_row & ((1 << (limb_bits - LOG_N_LANES)) - 1); - - // Construct the 3 vectors a, b, c. - let a0 = u32x16::splat((ah0 << limb_bits) | al); - let a1 = u32x16::splat((ah1 << limb_bits) | al); - // bll is just the consecutive numbers 0 .. N_LANES-1. - let b0 = u32x16::splat((bh0 << limb_bits) | (blh << LOG_N_LANES)) | offsets_vec; - let b1 = u32x16::splat((bh1 << limb_bits) | (blh << LOG_N_LANES)) | offsets_vec; - - let c0 = a0 ^ b0; - let c1 = a1 ^ b1; - - let p0: PackedSecureField = lookup_elements - .combine(&[a0, b0, c0].map(|x| unsafe { PackedBaseField::from_simd_unchecked(x) })); - let p1: PackedSecureField = lookup_elements - .combine(&[a1, b1, c1].map(|x| unsafe { PackedBaseField::from_simd_unchecked(x) })); - - let num = p1 * mults0.data[vec_row as usize] + p0 * mults1.data[vec_row as usize]; - let denom = p0 * p1; - col_gen.write_frac(vec_row as usize, -num, denom); - } - col_gen.finalize_col(); - } - - // If there is an odd number of lookup expressions, handle the last one. - if let Some(rem) = iter.into_remainder() { - if let Some((i, mults)) = rem.collect_vec().pop() { - let mut col_gen = logup_gen.new_col(); - let ah = i as u32 >> EXPAND_BITS; - let bh = i as u32 & ((1 << EXPAND_BITS) - 1); - - #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (column_bits::() - LOG_N_LANES)) { - // vec_row is LIMB_BITS of a, and LIMB_BITS - LOG_N_LANES of b. - let al = vec_row >> (limb_bits - LOG_N_LANES); - let a = u32x16::splat((ah << limb_bits) | al); - let bm = vec_row & ((1 << (limb_bits - LOG_N_LANES)) - 1); - let b = u32x16::splat((bh << limb_bits) | (bm << LOG_N_LANES)) | offsets_vec; - - let c = a ^ b; - - let p: PackedSecureField = lookup_elements.combine( - &[a, b, c].map(|x| unsafe { PackedBaseField::from_simd_unchecked(x) }), - ); - - let num = mults.data[vec_row as usize]; - let denom = p; - col_gen.write_frac(vec_row as usize, PackedSecureField::from(-num), denom); - } - col_gen.finalize_col(); - } - } - - let (interaction_trace, claimed_sum) = logup_gen.finalize(); - (interaction_trace, claimed_sum) -} - -/// Generates the constant trace for the xor table. -/// Returns the constant trace, the constant trace, and the claimed sum. -#[allow(clippy::type_complexity)] -pub fn generate_constant_trace( -) -> ColumnVec> { - let limb_bits = limb_bits::(); - let _span = span!(Level::INFO, "Xor constant trace").entered(); - - // Generate the constant columns. In reality, these should be generated before the proof - // even began. - let a_col: BaseColumn = (0..(1 << (column_bits::()))) - .map(|i| BaseField::from_u32_unchecked((i >> limb_bits) as u32)) - .collect(); - let b_col: BaseColumn = (0..(1 << (column_bits::()))) - .map(|i| BaseField::from_u32_unchecked((i & ((1 << limb_bits) - 1)) as u32)) - .collect(); - let c_col: BaseColumn = (0..(1 << (column_bits::()))) - .map(|i| { - BaseField::from_u32_unchecked(((i >> limb_bits) ^ (i & ((1 << limb_bits) - 1))) as u32) - }) - .collect(); - - [a_col, b_col, c_col] - .map(|x| { - CircleEvaluation::new( - CanonicCoset::new(column_bits::()).circle_domain(), - x, - ) - }) - .to_vec() -} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/mod.rs b/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/mod.rs deleted file mode 100644 index 877a651..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/mod.rs +++ /dev/null @@ -1,158 +0,0 @@ -#![allow(unused)] -//! Xor table component. -//! Generic on `ELEM_BITS` and `EXPAND_BITS`. -//! The table has all triplets of (a, b, a^b), where a, b are in the range [0,2^ELEM_BITS). -//! a,b are split into high and low parts, of size `EXPAND_BITS` and `ELEM_BITS - EXPAND_BITS` -//! respectively. -//! The component itself will hold 2^(2*EXPAND_BITS) multiplicity columns, each of size -//! 2^(ELEM_BITS - EXPAND_BITS). -//! The constant columns correspond only to the smaller table of the lower `ELEM_BITS - EXPAND_BITS` -//! xors: (a_l, b_l, a_l^b_l). -//! The rest of the lookups are computed based on these constant columns. - -mod constraints; -mod gen; - -use std::simd::u32x16; - -use itertools::Itertools; -use num_traits::Zero; -pub use r#gen::{generate_constant_trace, generate_interaction_trace, generate_trace}; - -use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; -use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator}; -use crate::core::backend::simd::column::BaseColumn; -use crate::core::backend::Column; -use crate::core::fields::qm31::SecureField; -use crate::core::pcs::{TreeSubspan, TreeVec}; - -pub fn trace_sizes() -> TreeVec> { - let component = XorTableEval:: { - lookup_elements: LookupElements::<3>::dummy(), - claimed_sum: SecureField::zero(), - }; - let info = component.evaluate(InfoEvaluator::default()); - info.mask_offsets - .as_cols_ref() - .map_cols(|_| column_bits::()) -} - -const fn limb_bits() -> u32 { - ELEM_BITS - EXPAND_BITS -} -pub const fn column_bits() -> u32 { - 2 * limb_bits::() -} - -/// Accumulator that keeps track of the number of times each input has been used. -pub struct XorAccumulator { - /// 2^(2*EXPAND_BITS) multiplicity columns. Index (al, bl) of column (ah, bh) is the number of - /// times ah||al ^ bh||bl has been used. - pub mults: Vec, -} -impl Default - for XorAccumulator -{ - fn default() -> Self { - Self { - mults: (0..(1 << (2 * EXPAND_BITS))) - .map(|_| BaseColumn::zeros(1 << column_bits::())) - .collect_vec(), - } - } -} -impl XorAccumulator { - pub fn add_input(&mut self, a: u32x16, b: u32x16) { - // Split a and b into high and low parts, according to ELEMENT_BITS and EXPAND_BITS. - // The high part is the index of the multiplicity column. - // The low part is the index of the element in that column. - let al = a & u32x16::splat((1 << limb_bits::()) - 1); - let ah = a >> limb_bits::(); - let bl = b & u32x16::splat((1 << limb_bits::()) - 1); - let bh = b >> limb_bits::(); - let column_idx = (ah << EXPAND_BITS) + bh; - let offset = (al << limb_bits::()) + bl; - - // Since the indices may collide, we cannot use scatter simd operations here. - // Instead, loop over packed values. - for (column_idx, offset) in column_idx.as_array().iter().zip(offset.as_array().iter()) { - self.mults[*column_idx as usize].as_mut_slice()[*offset as usize].0 += 1; - } - } -} - -/// Component that evaluates the xor table. -pub type XorTableComponent = - FrameworkComponent>; - -pub type XorElements = LookupElements<3>; - -/// Evaluates the xor table. -pub struct XorTableEval { - pub lookup_elements: XorElements, - pub claimed_sum: SecureField, -} - -impl FrameworkEval - for XorTableEval -{ - fn log_size(&self) -> u32 { - column_bits::() - } - fn max_constraint_log_degree_bound(&self) -> u32 { - column_bits::() + 1 - } - fn evaluate(&self, mut eval: E) -> E { - let xor_eval = constraints::XorTableEval::<'_, _, ELEM_BITS, EXPAND_BITS> { - eval, - lookup_elements: &self.lookup_elements, - logup: LogupAtRow::new(1, self.claimed_sum, self.log_size()), - }; - xor_eval.eval() - } -} - -#[cfg(test)] -mod tests { - use std::simd::u32x16; - - use crate::constraint_framework::logup::LookupElements; - use crate::constraint_framework::{assert_constraints, FrameworkEval}; - use crate::core::poly::circle::CanonicCoset; - use crate::examples::blake::xor_table::r#gen::{ - generate_constant_trace, generate_interaction_trace, generate_trace, - }; - use crate::examples::blake::xor_table::{column_bits, XorAccumulator, XorTableEval}; - - #[test] - fn test_xor_table() { - use crate::core::pcs::TreeVec; - - const ELEM_BITS: u32 = 9; - const EXPAND_BITS: u32 = 2; - - let mut xor_accum = XorAccumulator::::default(); - xor_accum.add_input(u32x16::splat(1), u32x16::splat(2)); - - let (trace, lookup_data) = generate_trace(xor_accum); - let lookup_elements = crate::examples::blake::xor_table::XorElements::dummy(); - let (interaction_trace, claimed_sum) = - generate_interaction_trace(lookup_data, &lookup_elements); - let constant_trace = generate_constant_trace::(); - - let trace = TreeVec::new(vec![trace, interaction_trace, constant_trace]); - let trace_polys = trace.map_cols(|c| c.interpolate()); - - let component = XorTableEval:: { - lookup_elements, - claimed_sum, - }; - assert_constraints( - &trace_polys, - CanonicCoset::new(column_bits::()), - |eval| { - component.evaluate(eval); - }, - ) - } -} diff --git a/Stwo_wrapper/crates/prover/src/examples/mod.rs b/Stwo_wrapper/crates/prover/src/examples/mod.rs deleted file mode 100644 index 330662d..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub mod blake; -pub mod plonk; -pub mod poseidon; -pub mod wide_fibonacci; -pub mod xor; diff --git a/Stwo_wrapper/crates/prover/src/examples/plonk/mod.rs b/Stwo_wrapper/crates/prover/src/examples/plonk/mod.rs deleted file mode 100644 index 58248a0..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/plonk/mod.rs +++ /dev/null @@ -1,300 +0,0 @@ -use itertools::{chain, Itertools}; -use num_traits::One; -use tracing::{span, Level}; - -use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements}; -use crate::constraint_framework::{ - assert_constraints, EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator, -}; -use crate::core::backend::simd::column::BaseColumn; -use crate::core::backend::simd::m31::LOG_N_LANES; -use crate::core::backend::simd::qm31::PackedSecureField; -use crate::core::backend::simd::SimdBackend; -use crate::core::backend::Column; -use crate::core::channel::Blake2sChannel; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::pcs::{CommitmentSchemeProver, PcsConfig, TreeSubspan}; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; -use crate::core::poly::BitReversedOrder; -use crate::core::prover::{prove, StarkProof}; -use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; -use crate::core::ColumnVec; - -pub type PlonkComponent = FrameworkComponent; - -#[derive(Clone)] -pub struct PlonkEval { - pub log_n_rows: u32, - pub lookup_elements: LookupElements<2>, - pub claimed_sum: SecureField, - pub base_trace_location: TreeSubspan, - pub interaction_trace_location: TreeSubspan, - pub constants_trace_location: TreeSubspan, -} - -impl FrameworkEval for PlonkEval { - fn log_size(&self) -> u32 { - self.log_n_rows - } - - fn max_constraint_log_degree_bound(&self) -> u32 { - self.log_n_rows + 1 - } - - fn evaluate(&self, mut eval: E) -> E { - let mut logup = LogupAtRow::<2, _>::new(1, self.claimed_sum, self.log_n_rows); - - let [a_wire] = eval.next_interaction_mask(2, [0]); - let [b_wire] = eval.next_interaction_mask(2, [0]); - // Note: c_wire could also be implicit: (self.eval.point() - M31_CIRCLE_GEN.into_ef()).x. - // A constant column is easier though. - let [c_wire] = eval.next_interaction_mask(2, [0]); - let [op] = eval.next_interaction_mask(2, [0]); - - let mult = eval.next_trace_mask(); - let a_val = eval.next_trace_mask(); - let b_val = eval.next_trace_mask(); - let c_val = eval.next_trace_mask(); - - eval.add_constraint(c_val - op * (a_val + b_val) + (E::F::one() - op) * a_val * b_val); - - logup.push_lookup( - &mut eval, - E::EF::one(), - &[a_wire, a_val], - &self.lookup_elements, - ); - logup.push_lookup( - &mut eval, - E::EF::one(), - &[b_wire, b_val], - &self.lookup_elements, - ); - logup.push_lookup( - &mut eval, - E::EF::from(-mult), - &[c_wire, c_val], - &self.lookup_elements, - ); - - logup.finalize(&mut eval); - eval - } -} - -#[derive(Clone)] -pub struct PlonkCircuitTrace { - pub mult: BaseColumn, - pub a_wire: BaseColumn, - pub b_wire: BaseColumn, - pub c_wire: BaseColumn, - pub op: BaseColumn, - pub a_val: BaseColumn, - pub b_val: BaseColumn, - pub c_val: BaseColumn, -} -pub fn gen_trace( - log_size: u32, - circuit: &PlonkCircuitTrace, -) -> ColumnVec> { - let _span = span!(Level::INFO, "Generation").entered(); - - let domain = CanonicCoset::new(log_size).circle_domain(); - [ - &circuit.mult, - &circuit.a_val, - &circuit.b_val, - &circuit.c_val, - ] - .into_iter() - .map(|eval| CircleEvaluation::::new(domain, eval.clone())) - .collect_vec() -} - -pub fn gen_interaction_trace( - log_size: u32, - circuit: &PlonkCircuitTrace, - lookup_elements: &LookupElements<2>, -) -> ( - ColumnVec>, - SecureField, -) { - let _span = span!(Level::INFO, "Generate interaction trace").entered(); - let mut logup_gen = LogupTraceGenerator::new(log_size); - - let mut col_gen = logup_gen.new_col(); - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { - let q0: PackedSecureField = - lookup_elements.combine(&[circuit.a_wire.data[vec_row], circuit.a_val.data[vec_row]]); - let q1: PackedSecureField = - lookup_elements.combine(&[circuit.b_wire.data[vec_row], circuit.b_val.data[vec_row]]); - col_gen.write_frac(vec_row, q0 + q1, q0 * q1); - } - col_gen.finalize_col(); - - let mut col_gen = logup_gen.new_col(); - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { - let p = -circuit.mult.data[vec_row]; - let q: PackedSecureField = - lookup_elements.combine(&[circuit.c_wire.data[vec_row], circuit.c_val.data[vec_row]]); - col_gen.write_frac(vec_row, p.into(), q); - } - col_gen.finalize_col(); - - logup_gen.finalize() -} - -#[allow(unused)] -pub fn prove_fibonacci_plonk( - log_n_rows: u32, - config: PcsConfig, -) -> (PlonkComponent, StarkProof) { - assert!(log_n_rows >= LOG_N_LANES); - - // Prepare a fibonacci circuit. - let mut fib_values = vec![BaseField::one(), BaseField::one()]; - for _ in 0..(1 << log_n_rows) { - fib_values.push(fib_values[fib_values.len() - 1] + fib_values[fib_values.len() - 2]); - } - let range = 0..(1 << log_n_rows); - let mut circuit = PlonkCircuitTrace { - mult: range.clone().map(|_| 2.into()).collect(), - a_wire: range.clone().map(|i| i.into()).collect(), - b_wire: range.clone().map(|i| (i + 1).into()).collect(), - c_wire: range.clone().map(|i| (i + 2).into()).collect(), - op: range.clone().map(|_| 1.into()).collect(), - a_val: range.clone().map(|i| fib_values[i]).collect(), - b_val: range.clone().map(|i| fib_values[i + 1]).collect(), - c_val: range.clone().map(|i| fib_values[i + 2]).collect(), - }; - circuit.mult.set((1 << log_n_rows) - 1, 0.into()); - circuit.mult.set((1 << log_n_rows) - 2, 1.into()); - - // Precompute twiddles. - let span = span!(Level::INFO, "Precompute twiddles").entered(); - let twiddles = SimdBackend::precompute_twiddles( - CanonicCoset::new(log_n_rows + config.fri_config.log_blowup_factor + 1) - .circle_domain() - .half_coset, - ); - span.exit(); - - // Setup protocol. - let channel = &mut Blake2sChannel::default(); - let commitment_scheme = - &mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); - - // Trace. - let span = span!(Level::INFO, "Trace").entered(); - let trace = gen_trace(log_n_rows, &circuit); - let mut tree_builder = commitment_scheme.tree_builder(); - let base_trace_location = tree_builder.extend_evals(trace); - tree_builder.commit(channel); - span.exit(); - - // Draw lookup element. - let lookup_elements = LookupElements::draw(channel); - - // Interaction trace. - let span = span!(Level::INFO, "Interaction").entered(); - let (trace, claimed_sum) = gen_interaction_trace(log_n_rows, &circuit, &lookup_elements); - let mut tree_builder = commitment_scheme.tree_builder(); - let interaction_trace_location = tree_builder.extend_evals(trace); - tree_builder.commit(channel); - span.exit(); - - // Constant trace. - let span = span!(Level::INFO, "Constant").entered(); - let mut tree_builder = commitment_scheme.tree_builder(); - let constants_trace_location = tree_builder.extend_evals( - chain!([circuit.a_wire, circuit.b_wire, circuit.c_wire, circuit.op] - .into_iter() - .map(|col| { - CircleEvaluation::::new( - CanonicCoset::new(log_n_rows).circle_domain(), - col, - ) - })) - .collect_vec(), - ); - tree_builder.commit(channel); - span.exit(); - - // Prove constraints. - let component = PlonkComponent::new( - &mut TraceLocationAllocator::default(), - PlonkEval { - log_n_rows, - lookup_elements, - claimed_sum, - base_trace_location, - interaction_trace_location, - constants_trace_location, - }, - ); - - // Sanity check. Remove for production. - let trace_polys = commitment_scheme - .trees - .as_ref() - .map(|t| t.polynomials.iter().cloned().collect_vec()); - assert_constraints(&trace_polys, CanonicCoset::new(log_n_rows), |mut eval| { - component.evaluate(eval); - }); - - let proof = prove::(&[&component], channel, commitment_scheme).unwrap(); - - (component, proof) -} - -#[cfg(test)] -mod tests { - use std::env; - - use crate::constraint_framework::logup::LookupElements; - use crate::core::air::Component; - use crate::core::channel::Blake2sChannel; - use crate::core::fri::FriConfig; - use crate::core::pcs::{CommitmentSchemeVerifier, PcsConfig}; - use crate::core::prover::verify; - use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; - use crate::examples::plonk::prove_fibonacci_plonk; - - #[test_log::test] - fn test_simd_plonk_prove() { - // Get from environment variable: - let log_n_instances = env::var("LOG_N_INSTANCES") - .unwrap_or_else(|_| "10".to_string()) - .parse::() - .unwrap(); - let config = PcsConfig { - pow_bits: 10, - fri_config: FriConfig::new(5, 4, 64), - }; - - // Prove. - let (component, proof) = prove_fibonacci_plonk(log_n_instances, config); - - // Verify. - // TODO: Create Air instance independently. - let channel = &mut Blake2sChannel::default(); - let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); - - // Decommit. - // Retrieve the expected column sizes in each commitment interaction, from the AIR. - let sizes = component.trace_log_degree_bounds(); - // Trace columns. - commitment_scheme.commit(proof.commitments[0], &sizes[0], channel); - // Draw lookup element. - let lookup_elements = LookupElements::<2>::draw(channel); - assert_eq!(lookup_elements, component.lookup_elements); - // TODO(spapini): Check claimed sum against first and last instances. - // Interaction columns. - commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); - // Constant columns. - commitment_scheme.commit(proof.commitments[2], &sizes[2], channel); - - verify(&[&component], channel, commitment_scheme, proof).unwrap(); - } -} diff --git a/Stwo_wrapper/crates/prover/src/examples/poseidon/mod.rs b/Stwo_wrapper/crates/prover/src/examples/poseidon/mod.rs deleted file mode 100644 index c94f0ba..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/poseidon/mod.rs +++ /dev/null @@ -1,508 +0,0 @@ -//! AIR for Poseidon2 hash function from . - -use std::ops::{Add, AddAssign, Mul, Sub}; - -use itertools::Itertools; -use num_traits::One; -use tracing::{span, Level}; - -use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements}; -use crate::constraint_framework::{ - EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator, -}; -use crate::core::backend::simd::column::BaseColumn; -use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; -use crate::core::backend::simd::qm31::PackedSecureField; -use crate::core::backend::simd::SimdBackend; -use crate::core::backend::{Col, Column}; -use crate::core::channel::Blake2sChannel; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::FieldExpOps; -use crate::core::pcs::{CommitmentSchemeProver, PcsConfig}; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; -use crate::core::poly::BitReversedOrder; -use crate::core::prover::{prove, StarkProof}; -use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; -use crate::core::ColumnVec; - -const N_LOG_INSTANCES_PER_ROW: usize = 3; -const N_INSTANCES_PER_ROW: usize = 1 << N_LOG_INSTANCES_PER_ROW; -const N_STATE: usize = 16; -const N_PARTIAL_ROUNDS: usize = 14; -const N_HALF_FULL_ROUNDS: usize = 4; -const FULL_ROUNDS: usize = 2 * N_HALF_FULL_ROUNDS; -const N_COLUMNS_PER_REP: usize = N_STATE * (1 + FULL_ROUNDS) + N_PARTIAL_ROUNDS; -const N_COLUMNS: usize = N_INSTANCES_PER_ROW * N_COLUMNS_PER_REP; -const LOG_EXPAND: u32 = 2; -// TODO(spapini): Pick better constants. -const EXTERNAL_ROUND_CONSTS: [[BaseField; N_STATE]; 2 * N_HALF_FULL_ROUNDS] = - [[BaseField::from_u32_unchecked(1234); N_STATE]; 2 * N_HALF_FULL_ROUNDS]; -const INTERNAL_ROUND_CONSTS: [BaseField; N_PARTIAL_ROUNDS] = - [BaseField::from_u32_unchecked(1234); N_PARTIAL_ROUNDS]; - -pub type PoseidonComponent = FrameworkComponent; - -pub type PoseidonElements = LookupElements<{ N_STATE * 2 }>; - -#[derive(Clone)] -pub struct PoseidonEval { - pub log_n_rows: u32, - pub lookup_elements: PoseidonElements, - pub claimed_sum: SecureField, -} -impl FrameworkEval for PoseidonEval { - fn log_size(&self) -> u32 { - self.log_n_rows - } - fn max_constraint_log_degree_bound(&self) -> u32 { - self.log_n_rows + LOG_EXPAND - } - fn evaluate(&self, mut eval: E) -> E { - let logup = LogupAtRow::new(1, self.claimed_sum, self.log_n_rows); - eval_poseidon_constraints(&mut eval, logup, &self.lookup_elements); - eval - } -} - -#[inline(always)] -/// Applies the M4 MDS matrix described in 5.1. -fn apply_m4(x: [F; 4]) -> [F; 4] -where - F: Copy + AddAssign + Add + Sub + Mul, -{ - let t0 = x[0] + x[1]; - let t02 = t0 + t0; - let t1 = x[2] + x[3]; - let t12 = t1 + t1; - let t2 = x[1] + x[1] + t1; - let t3 = x[3] + x[3] + t0; - let t4 = t12 + t12 + t3; - let t5 = t02 + t02 + t2; - let t6 = t3 + t5; - let t7 = t2 + t4; - [t6, t5, t7, t4] -} - -/// Applies the external round matrix. -/// See 5.1 and Appendix B. -fn apply_external_round_matrix(state: &mut [F; 16]) -where - F: Copy + AddAssign + Add + Sub + Mul, -{ - // Applies circ(2M4, M4, M4, M4). - for i in 0..4 { - [ - state[4 * i], - state[4 * i + 1], - state[4 * i + 2], - state[4 * i + 3], - ] = apply_m4([ - state[4 * i], - state[4 * i + 1], - state[4 * i + 2], - state[4 * i + 3], - ]); - } - for j in 0..4 { - let s = state[j] + state[j + 4] + state[j + 8] + state[j + 12]; - for i in 0..4 { - state[4 * i + j] += s; - } - } -} - -// Applies the internal round matrix. -// mu_i = 2^{i+1} + 1. -// See 5.2. -fn apply_internal_round_matrix(state: &mut [F; 16]) -where - F: Copy + AddAssign + Add + Sub + Mul, -{ - // TODO(spapini): Check that these coefficients are good according to section 5.3 of Poseidon2 - // paper. - let sum = state[1..].iter().fold(state[0], |acc, s| acc + *s); - state.iter_mut().enumerate().for_each(|(i, s)| { - // TODO(spapini): Change to rotations. - *s = *s * BaseField::from_u32_unchecked(1 << (i + 1)) + sum; - }); -} - -fn pow5(x: F) -> F { - let x2 = x * x; - let x4 = x2 * x2; - x4 * x -} - -pub fn eval_poseidon_constraints( - eval: &mut E, - mut logup: LogupAtRow<2, E>, - lookup_elements: &PoseidonElements, -) { - for _ in 0..N_INSTANCES_PER_ROW { - let mut state: [_; N_STATE] = std::array::from_fn(|_| eval.next_trace_mask()); - - // Require state lookup. - logup.push_lookup(eval, E::EF::one(), &state, lookup_elements); - - // 4 full rounds. - (0..N_HALF_FULL_ROUNDS).for_each(|round| { - (0..N_STATE).for_each(|i| { - state[i] += EXTERNAL_ROUND_CONSTS[round][i]; - }); - apply_external_round_matrix(&mut state); - state = std::array::from_fn(|i| pow5(state[i])); - state.iter_mut().for_each(|s| { - let m = eval.next_trace_mask(); - eval.add_constraint(*s - m); - *s = m; - }); - }); - - // Partial rounds. - (0..N_PARTIAL_ROUNDS).for_each(|round| { - state[0] += INTERNAL_ROUND_CONSTS[round]; - apply_internal_round_matrix(&mut state); - state[0] = pow5(state[0]); - let m = eval.next_trace_mask(); - eval.add_constraint(state[0] - m); - state[0] = m; - }); - - // 4 full rounds. - (0..N_HALF_FULL_ROUNDS).for_each(|round| { - (0..N_STATE).for_each(|i| { - state[i] += EXTERNAL_ROUND_CONSTS[round + N_HALF_FULL_ROUNDS][i]; - }); - apply_external_round_matrix(&mut state); - state = std::array::from_fn(|i| pow5(state[i])); - state.iter_mut().for_each(|s| { - let m = eval.next_trace_mask(); - eval.add_constraint(*s - m); - *s = m; - }); - }); - - // Provide state lookup. - logup.push_lookup(eval, -E::EF::one(), &state, lookup_elements); - } - - logup.finalize(eval); -} - -pub struct LookupData { - initial_state: [[BaseColumn; N_STATE]; N_INSTANCES_PER_ROW], - final_state: [[BaseColumn; N_STATE]; N_INSTANCES_PER_ROW], -} -pub fn gen_trace( - log_size: u32, -) -> ( - ColumnVec>, - LookupData, -) { - let _span = span!(Level::INFO, "Generation").entered(); - assert!(log_size >= LOG_N_LANES); - let mut trace = (0..N_COLUMNS) - .map(|_| Col::::zeros(1 << log_size)) - .collect_vec(); - let mut lookup_data = LookupData { - initial_state: std::array::from_fn(|_| { - std::array::from_fn(|_| BaseColumn::zeros(1 << log_size)) - }), - final_state: std::array::from_fn(|_| { - std::array::from_fn(|_| BaseColumn::zeros(1 << log_size)) - }), - }; - - for vec_index in 0..(1 << (log_size - LOG_N_LANES)) { - // Initial state. - let mut col_index = 0; - for rep_i in 0..N_INSTANCES_PER_ROW { - let mut state: [_; N_STATE] = std::array::from_fn(|state_i| { - PackedBaseField::from_array(std::array::from_fn(|i| { - BaseField::from_u32_unchecked((vec_index * 16 + i + state_i + rep_i) as u32) - })) - }); - state.iter().copied().for_each(|s| { - trace[col_index].data[vec_index] = s; - col_index += 1; - }); - lookup_data.initial_state[rep_i] - .iter_mut() - .zip(state) - .for_each(|(res, state_i)| res.data[vec_index] = state_i); - - // 4 full rounds. - (0..N_HALF_FULL_ROUNDS).for_each(|round| { - (0..N_STATE).for_each(|i| { - state[i] += PackedBaseField::broadcast(EXTERNAL_ROUND_CONSTS[round][i]); - }); - apply_external_round_matrix(&mut state); - state = std::array::from_fn(|i| pow5(state[i])); - state.iter().copied().for_each(|s| { - trace[col_index].data[vec_index] = s; - col_index += 1; - }); - }); - - // Partial rounds. - (0..N_PARTIAL_ROUNDS).for_each(|round| { - state[0] += PackedBaseField::broadcast(INTERNAL_ROUND_CONSTS[round]); - apply_internal_round_matrix(&mut state); - state[0] = pow5(state[0]); - trace[col_index].data[vec_index] = state[0]; - col_index += 1; - }); - - // 4 full rounds. - (0..N_HALF_FULL_ROUNDS).for_each(|round| { - (0..N_STATE).for_each(|i| { - state[i] += PackedBaseField::broadcast( - EXTERNAL_ROUND_CONSTS[round + N_HALF_FULL_ROUNDS][i], - ); - }); - apply_external_round_matrix(&mut state); - state = std::array::from_fn(|i| pow5(state[i])); - state.iter().copied().for_each(|s| { - trace[col_index].data[vec_index] = s; - col_index += 1; - }); - }); - - lookup_data.final_state[rep_i] - .iter_mut() - .zip(state) - .for_each(|(res, state_i)| res.data[vec_index] = state_i); - } - } - let domain = CanonicCoset::new(log_size).circle_domain(); - let trace = trace - .into_iter() - .map(|eval| CircleEvaluation::::new(domain, eval)) - .collect_vec(); - (trace, lookup_data) -} - -pub fn gen_interaction_trace( - log_size: u32, - lookup_data: LookupData, - lookup_elements: &PoseidonElements, -) -> ( - ColumnVec>, - SecureField, -) { - let _span = span!(Level::INFO, "Generate interaction trace").entered(); - let mut logup_gen = LogupTraceGenerator::new(log_size); - - #[allow(clippy::needless_range_loop)] - for rep_i in 0..N_INSTANCES_PER_ROW { - let mut col_gen = logup_gen.new_col(); - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { - // Batch the 2 lookups together. - let denom0: PackedSecureField = lookup_elements.combine( - &lookup_data.initial_state[rep_i] - .each_ref() - .map(|s| s.data[vec_row]), - ); - let denom1: PackedSecureField = lookup_elements.combine( - &lookup_data.final_state[rep_i] - .each_ref() - .map(|s| s.data[vec_row]), - ); - // (1 / denom1) - (1 / denom1) = (denom1 - denom0) / (denom0 * denom1). - col_gen.write_frac(vec_row, denom1 - denom0, denom0 * denom1); - } - col_gen.finalize_col(); - } - - logup_gen.finalize() -} - -pub fn prove_poseidon( - log_n_instances: u32, - config: PcsConfig, -) -> (PoseidonComponent, StarkProof) { - assert!(log_n_instances >= N_LOG_INSTANCES_PER_ROW as u32); - let log_n_rows = log_n_instances - N_LOG_INSTANCES_PER_ROW as u32; - - // Precompute twiddles. - let span = span!(Level::INFO, "Precompute twiddles").entered(); - let twiddles = SimdBackend::precompute_twiddles( - CanonicCoset::new(log_n_rows + LOG_EXPAND + config.fri_config.log_blowup_factor) - .circle_domain() - .half_coset, - ); - span.exit(); - - // Setup protocol. - let channel = &mut Blake2sChannel::default(); - let commitment_scheme = - &mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); - - // Trace. - let span = span!(Level::INFO, "Trace").entered(); - let (trace, lookup_data) = gen_trace(log_n_rows); - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(trace); - tree_builder.commit(channel); - span.exit(); - - // Draw lookup elements. - let lookup_elements = PoseidonElements::draw(channel); - - // Interaction trace. - let span = span!(Level::INFO, "Interaction").entered(); - let (trace, claimed_sum) = gen_interaction_trace(log_n_rows, lookup_data, &lookup_elements); - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(trace); - tree_builder.commit(channel); - span.exit(); - - // Prove constraints. - let component = PoseidonComponent::new( - &mut TraceLocationAllocator::default(), - PoseidonEval { - log_n_rows, - lookup_elements, - claimed_sum, - }, - ); - let proof = prove::(&[&component], channel, commitment_scheme).unwrap(); - (component, proof) -} - -#[cfg(test)] -mod tests { - use std::env; - - use itertools::Itertools; - use num_traits::One; - - use crate::constraint_framework::assert_constraints; - use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; - use crate::core::air::Component; - use crate::core::channel::Blake2sChannel; - use crate::core::fields::m31::BaseField; - use crate::core::fri::FriConfig; - use crate::core::pcs::{CommitmentSchemeVerifier, PcsConfig, TreeVec}; - use crate::core::poly::circle::CanonicCoset; - use crate::core::prover::verify; - use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; - use crate::examples::poseidon::{ - apply_internal_round_matrix, apply_m4, eval_poseidon_constraints, gen_interaction_trace, - gen_trace, prove_poseidon, PoseidonElements, - }; - use crate::math::matrix::{RowMajorMatrix, SquareMatrix}; - - #[cfg(all(target_family = "wasm", not(target_os = "wasi")))] - #[wasm_bindgen_test::wasm_bindgen_test] - fn test_poseidon_prove_wasm() { - const LOG_N_INSTANCES: u32 = 10; - let config = PcsConfig { - pow_bits: 10, - fri_config: FriConfig::new(5, 1, 64), - }; - - // Prove. - prove_poseidon(LOG_N_INSTANCES, config); - } - - #[test] - fn test_apply_m4() { - let m4 = RowMajorMatrix::::new( - [5, 7, 1, 3, 4, 6, 1, 1, 1, 3, 5, 7, 1, 1, 4, 6] - .map(BaseField::from_u32_unchecked) - .into_iter() - .collect_vec(), - ); - let state = (0..4) - .map(BaseField::from_u32_unchecked) - .collect_vec() - .try_into() - .unwrap(); - - assert_eq!(apply_m4(state), m4.mul(state)); - } - - #[test] - fn test_apply_internal() { - let mut state: [BaseField; 16] = (0..16) - .map(|i| BaseField::from_u32_unchecked(i * 3 + 187)) - .collect_vec() - .try_into() - .unwrap(); - let mut internal_matrix = [[BaseField::one(); 16]; 16]; - #[allow(clippy::needless_range_loop)] - for i in 0..16 { - internal_matrix[i][i] += BaseField::from_u32_unchecked(1 << (i + 1)); - } - let matrix = RowMajorMatrix::::new(internal_matrix.flatten().to_vec()); - - let expected_state = matrix.mul(state); - apply_internal_round_matrix(&mut state); - - assert_eq!(state, expected_state); - } - - #[test] - fn test_poseidon_constraints() { - const LOG_N_ROWS: u32 = 8; - - // Trace. - let (trace0, interaction_data) = gen_trace(LOG_N_ROWS); - let lookup_elements = LookupElements::dummy(); - let (trace1, claimed_sum) = - gen_interaction_trace(LOG_N_ROWS, interaction_data, &lookup_elements); - - let traces = TreeVec::new(vec![trace0, trace1]); - let trace_polys = - traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); - assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |mut eval| { - eval_poseidon_constraints( - &mut eval, - LogupAtRow::new(1, claimed_sum, LOG_N_ROWS), - &lookup_elements, - ); - }); - } - - #[test_log::test] - fn test_simd_poseidon_prove() { - // Note: To see time measurement, run test with - // RUST_LOG_SPAN_EVENTS=enter,close RUST_LOG=info RUST_BACKTRACE=1 RUSTFLAGS=" - // -C target-cpu=native -C target-feature=+avx512f -C opt-level=3" cargo test - // test_simd_poseidon_prove -- --nocapture - - // Get from environment variable: - let log_n_instances = env::var("LOG_N_INSTANCES") - .unwrap_or_else(|_| "10".to_string()) - .parse::() - .unwrap(); - let config = PcsConfig { - pow_bits: 10, - fri_config: FriConfig::new(5, 1, 64), - }; - - // Prove. - let (component, proof) = prove_poseidon(log_n_instances, config); - - // Verify. - // TODO: Create Air instance independently. - let channel = &mut Blake2sChannel::default(); - let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); - - // Decommit. - // Retrieve the expected column sizes in each commitment interaction, from the AIR. - let sizes = component.trace_log_degree_bounds(); - // Trace columns. - commitment_scheme.commit(proof.commitments[0], &sizes[0], channel); - // Draw lookup element. - let lookup_elements = PoseidonElements::draw(channel); - assert_eq!(lookup_elements, component.lookup_elements); - // TODO(spapini): Check claimed sum against first and last instances. - // Interaction columns. - commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); - - verify(&[&component], channel, commitment_scheme, proof).unwrap(); - } -} diff --git a/Stwo_wrapper/crates/prover/src/examples/wide_fibonacci/mod.rs b/Stwo_wrapper/crates/prover/src/examples/wide_fibonacci/mod.rs deleted file mode 100644 index 1ac82d0..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/wide_fibonacci/mod.rs +++ /dev/null @@ -1,619 +0,0 @@ -use itertools::Itertools; -use std::fs::File; -use std::io::{Error, Write}; -use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval}; -use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; -use crate::core::backend::simd::SimdBackend; -use crate::core::backend::{Col, Column}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::FieldExpOps; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; -use crate::core::poly::BitReversedOrder; -use crate::core::ColumnVec; -use crate::core::fields::qm31::SecureField; -use crate::core::prover::StarkProof; -//use crate::core::vcs::hash::Hash; -use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleHasher; - -pub type WideFibonacciComponent = FrameworkComponent>; - -pub struct FibInput { - a: PackedBaseField, - b: PackedBaseField, -} - -/// A component that enforces the Fibonacci sequence. -/// Each row contains a seperate Fibonacci sequence of length `N`. -#[derive(Clone)] -pub struct WideFibonacciEval { - pub log_n_rows: u32, -} -impl FrameworkEval for WideFibonacciEval { - fn log_size(&self) -> u32 { - self.log_n_rows - } - fn max_constraint_log_degree_bound(&self) -> u32 { - self.log_n_rows + 1 - } - fn evaluate(&self, mut eval: E) -> E { - let mut a = eval.next_trace_mask(); - let mut b = eval.next_trace_mask(); - for _ in 2..N { - let c = eval.next_trace_mask(); - eval.add_constraint(c - (a.square() + b.square())); - a = b; - b = c; - } - eval - } -} - -pub fn generate_trace( - log_size: u32, - inputs: &[FibInput], -) -> ColumnVec> { - assert!(log_size >= LOG_N_LANES); - assert_eq!(inputs.len(), 1 << (log_size - LOG_N_LANES)); - let mut trace = (0..N) - .map(|_| Col::::zeros(1 << log_size)) - .collect_vec(); - for (vec_index, input) in inputs.iter().enumerate() { - let mut a = input.a; - let mut b = input.b; - trace[0].data[vec_index] = a; - trace[1].data[vec_index] = b; - trace.iter_mut().skip(2).for_each(|col| { - (a, b) = (b, a.square() + b.square()); - col.data[vec_index] = b; - }); - } - let domain = CanonicCoset::new(log_size).circle_domain(); - trace - .into_iter() - .map(|eval| CircleEvaluation::::new(domain, eval)) - .collect_vec() -} - -pub fn save_secure_field_element(file: &mut File, q31_element: SecureField) -> Result<(), Error> { - file.write_all(b"[\"")?; - file.write_all(&q31_element.0.0.to_string().into_bytes())?; - file.write_all(b"\",\"")?; - file.write_all(&q31_element.0.1.to_string().into_bytes())?; - file.write_all(b"\",\"")?; - file.write_all(&q31_element.1.0.to_string().into_bytes())?; - file.write_all(b"\",\"")?; - file.write_all(&q31_element.1.1.to_string().into_bytes())?; - file.write_all(b"\"]")?; - Ok(()) -} - -pub fn pretty_save_poseidon_bls_proof(proof: &StarkProof) -> Result<(),Error> { - let mut file = File::create("proof.json")?; - file.write_all(b"{\n")?; - //commitments - file.write_all(b"\t\"commitments\" : \n\t\t[")?; - for i in 0..proof.commitments.len() { - file.write_all(b"\"")?; - file.write_all(&proof.commitments[i].to_string().into_bytes())?; - file.write_all(b"\"")?; - if proof.commitments.len() != i+1 { - file.write_all(b",\n\t\t")?; - } - } - file.write_all(b"],\n\n")?; - // Sampled Values - file.write_all(b"\t\"sampled_values_0\" : \n\t\t[")?; - for i in 0..proof.commitment_scheme_proof.sampled_values.0[0].len() { - save_secure_field_element(&mut file,proof.commitment_scheme_proof.sampled_values.0[0][i][0])?; - if proof.commitment_scheme_proof.sampled_values.0[0].len() != i+1 { - file.write_all(b",\n\t\t")?; - } - } - file.write_all(b"],\n\n")?; - file.write_all(b"\t\"sampled_values_1\" : \n\t\t[")?; - for i in 0..proof.commitment_scheme_proof.sampled_values.0[1].len() { - save_secure_field_element(&mut file,proof.commitment_scheme_proof.sampled_values.0[1][i][0])?; - if proof.commitment_scheme_proof.sampled_values.0[1].len() != i+1 { - file.write_all(b",\n\t\t")?; - } - } - file.write_all(b"],\n\n")?; - - //decommitments - file.write_all(b"\t\"decommitment_0\" : \n\t\t[")?; - for i in 0..proof.commitment_scheme_proof.decommitments.0[0].hash_witness.len() { - file.write_all(b"\"")?; - file.write_all(&proof.commitment_scheme_proof.decommitments.0[0].hash_witness[i].to_string().into_bytes())?; - file.write_all(b"\"")?; - if proof.commitment_scheme_proof.decommitments.0[0].hash_witness.len() != i+1 { - file.write_all(b",\n\t\t")?; - } - } - file.write_all(b"],\n\n")?; - file.write_all(b"\t\"decommitment_1\" : \n\t\t[")?; - for i in 0..proof.commitment_scheme_proof.decommitments.0[1].hash_witness.len() { - file.write_all(b"\"")?; - file.write_all(&proof.commitment_scheme_proof.decommitments.0[1].hash_witness[i].to_string().into_bytes())?; - file.write_all(b"\"")?; - if proof.commitment_scheme_proof.decommitments.0[1].hash_witness.len() != i+1 { - file.write_all(b",\n\t\t")?; - } - } - file.write_all(b"],\n\n")?; - - //Queried_values - file.write_all(b"\t\"queried_values_0\" : \n\t\t[")?; - for i in 0..proof.commitment_scheme_proof.queried_values.0[0].len() { - file.write_all(b"[")?; - for j in 0..6 { - file.write_all(b"\"")?; - file.write_all(&proof.commitment_scheme_proof.queried_values.0[0][i][j].to_string().into_bytes())?; - file.write_all(b"\"")?; - if j != 5 { - file.write_all(b",")?; - } else { file.write_all(b"]")?; } - } - if proof.commitment_scheme_proof.queried_values.0[0].len() != i+1 { - file.write_all(b",\n\t\t")?; - } - } - file.write_all(b"],\n\n")?; - file.write_all(b"\t\"queried_values_1\" : [")?; - for i in 0..proof.commitment_scheme_proof.queried_values.0[1].len() { - file.write_all(b"[")?; - for j in 0..6 { - file.write_all(b"\"")?; - file.write_all(&proof.commitment_scheme_proof.queried_values.0[1][i][j].to_string().into_bytes())?; - file.write_all(b"\"")?; - if j != 5 { - file.write_all(b",")?; - } else { file.write_all(b"]")?; } - } - if proof.commitment_scheme_proof.queried_values.0[1].len() != i+1 { - file.write_all(b",\n\t\t")?; - } - } - file.write_all(b"],\n\n")?; - - //proof of work - file.write_all(b"\t\"proof of work\" : \"")?; - file.write_all(&proof.commitment_scheme_proof.proof_of_work.to_string().into_bytes())?; - file.write_all(b"\",\n\n")?; - - //last FRI layer coeffs - file.write_all(b"\t\"coeffs\" : ")?; - save_secure_field_element(&mut file,proof.commitment_scheme_proof.fri_proof.last_layer_poly.coeffs[0] )?; - file.write_all(b",\n\n")?; - - //intermediate FRI layers - for i in 0..6 { - //commitment - file.write_all(b"\t\"inner_commitment_")?; - file.write_all(&i.to_string().into_bytes())?; - file.write_all(b"\" : \"")?; - file.write_all(&proof.commitment_scheme_proof.fri_proof.inner_layers[i].commitment.to_string().into_bytes())?; - file.write_all(b"\",\n\n")?; - - //decommitment - file.write_all(b"\t\"inner_decommitment_")?; - file.write_all(&i.to_string().into_bytes())?; - file.write_all(b"\" : \n\t\t[")?; - for j in 0..proof.commitment_scheme_proof.fri_proof.inner_layers[i].decommitment.hash_witness.len() { - file.write_all(b"\"")?; - file.write_all(&proof.commitment_scheme_proof.fri_proof.inner_layers[i].decommitment.hash_witness[j].to_string().into_bytes())?; - file.write_all(b"\"")?; - if proof.commitment_scheme_proof.fri_proof.inner_layers[i].decommitment.hash_witness.len() != j+1 { - file.write_all(b",\n\t\t")?; - } - } - file.write_all(b"],\n\n")?; - - //evals_subset - file.write_all(b"\t\"inner_evals_subset_")?; - file.write_all(&i.to_string().into_bytes())?; - file.write_all(b"\" : \n\t\t[")?; - for j in 0..proof.commitment_scheme_proof.fri_proof.inner_layers[i].evals_subset.len() { - save_secure_field_element(&mut file,proof.commitment_scheme_proof.fri_proof.inner_layers[i].evals_subset[j])?; - if proof.commitment_scheme_proof.fri_proof.inner_layers[i].evals_subset.len() != j+1 { - file.write_all(b",\n\t\t")?; - } - } - if i != 5 { - file.write_all(b"],\n\n")?; - } else { - file.write_all(b"]\n}")?; - } - } - Ok(()) -} - -pub fn compressed_save_poseidon_bls_proof(proof: &StarkProof) -> Result<(),Error> { - let mut file = File::create("proof.json")?; - file.write_all(b"{")?; - //commitments - file.write_all(b"\"commitments\":[")?; - for i in 0..proof.commitments.len() { - file.write_all(b"\"")?; - file.write_all(&proof.commitments[i].to_string().into_bytes())?; - file.write_all(b"\"")?; - if proof.commitments.len() != i+1 { - file.write_all(b",")?; - } - } - file.write_all(b"],")?; - // Sampled Values - file.write_all(b"\"sampled_values_0\":[")?; - for i in 0..proof.commitment_scheme_proof.sampled_values.0[0].len() { - save_secure_field_element(&mut file,proof.commitment_scheme_proof.sampled_values.0[0][i][0])?; - if proof.commitment_scheme_proof.sampled_values.0[0].len() != i+1 { - file.write_all(b",")?; - } - } - file.write_all(b"],")?; - file.write_all(b"\"sampled_values_1\":[")?; - for i in 0..proof.commitment_scheme_proof.sampled_values.0[1].len() { - save_secure_field_element(&mut file,proof.commitment_scheme_proof.sampled_values.0[1][i][0])?; - if proof.commitment_scheme_proof.sampled_values.0[1].len() != i+1 { - file.write_all(b",")?; - } - } - file.write_all(b"],")?; - - //decommitments - file.write_all(b"\"decommitment_0\":[")?; - for i in 0..proof.commitment_scheme_proof.decommitments.0[0].hash_witness.len() { - file.write_all(b"\"")?; - file.write_all(&proof.commitment_scheme_proof.decommitments.0[0].hash_witness[i].to_string().into_bytes())?; - file.write_all(b"\"")?; - if proof.commitment_scheme_proof.decommitments.0[0].hash_witness.len() != i+1 { - file.write_all(b",")?; - } - } - file.write_all(b"],")?; - file.write_all(b"\"decommitment_1\":[")?; - for i in 0..proof.commitment_scheme_proof.decommitments.0[1].hash_witness.len() { - file.write_all(b"\"")?; - file.write_all(&proof.commitment_scheme_proof.decommitments.0[1].hash_witness[i].to_string().into_bytes())?; - file.write_all(b"\"")?; - if proof.commitment_scheme_proof.decommitments.0[1].hash_witness.len() != i+1 { - file.write_all(b",")?; - } - } - file.write_all(b"],")?; - - //Queried_values - file.write_all(b"\"queried_values_0\":[")?; - for i in 0..proof.commitment_scheme_proof.queried_values.0[0].len() { - file.write_all(b"[")?; - for j in 0..6 { - file.write_all(b"\"")?; - file.write_all(&proof.commitment_scheme_proof.queried_values.0[0][i][j].to_string().into_bytes())?; - file.write_all(b"\"")?; - if j != 5 { - file.write_all(b",")?; - } else { file.write_all(b"]")?; } - } - if proof.commitment_scheme_proof.queried_values.0[0].len() != i+1 { - file.write_all(b",")?; - } - } - file.write_all(b"],")?; - file.write_all(b"\"queried_values_1\":[")?; - for i in 0..proof.commitment_scheme_proof.queried_values.0[1].len() { - file.write_all(b"[")?; - for j in 0..6 { - file.write_all(b"\"")?; - file.write_all(&proof.commitment_scheme_proof.queried_values.0[1][i][j].to_string().into_bytes())?; - file.write_all(b"\"")?; - if j != 5 { - file.write_all(b",")?; - } else { file.write_all(b"]")?; } - } - if proof.commitment_scheme_proof.queried_values.0[1].len() != i+1 { - file.write_all(b",")?; - } - } - file.write_all(b"],")?; - - //proof of work - file.write_all(b"\"proof of work\":\"")?; - file.write_all(&proof.commitment_scheme_proof.proof_of_work.to_string().into_bytes())?; - file.write_all(b"\",")?; - - //last FRI layer coeffs - file.write_all(b"\"coeffs\":")?; - save_secure_field_element(&mut file,proof.commitment_scheme_proof.fri_proof.last_layer_poly.coeffs[0] )?; - file.write_all(b",")?; - - //intermediate FRI layers - for i in 0..6 { - //commitment - file.write_all(b"\"inner_commitment_")?; - file.write_all(&i.to_string().into_bytes())?; - file.write_all(b"\":\"")?; - file.write_all(&proof.commitment_scheme_proof.fri_proof.inner_layers[i].commitment.to_string().into_bytes())?; - file.write_all(b"\",")?; - - //decommitment - file.write_all(b"\"inner_decommitment_")?; - file.write_all(&i.to_string().into_bytes())?; - file.write_all(b"\":[")?; - for j in 0..proof.commitment_scheme_proof.fri_proof.inner_layers[i].decommitment.hash_witness.len() { - file.write_all(b"\"")?; - file.write_all(&proof.commitment_scheme_proof.fri_proof.inner_layers[i].decommitment.hash_witness[j].to_string().into_bytes())?; - file.write_all(b"\"")?; - if proof.commitment_scheme_proof.fri_proof.inner_layers[i].decommitment.hash_witness.len() != j+1 { - file.write_all(b",")?; - } - } - file.write_all(b"],")?; - - //evals_subset - file.write_all(b"\"inner_evals_subset_")?; - file.write_all(&i.to_string().into_bytes())?; - file.write_all(b"\":[")?; - for j in 0..proof.commitment_scheme_proof.fri_proof.inner_layers[i].evals_subset.len() { - save_secure_field_element(&mut file,proof.commitment_scheme_proof.fri_proof.inner_layers[i].evals_subset[j])?; - if proof.commitment_scheme_proof.fri_proof.inner_layers[i].evals_subset.len() != j+1 { - file.write_all(b",")?; - } - } - if i != 5 { - file.write_all(b"],")?; - } else { - file.write_all(b"]}")?; - } - } - - - - - Ok(()) -} - -#[cfg(test)] -mod tests { - use itertools::Itertools; - use num_traits::One; - - use super::{pretty_save_poseidon_bls_proof, WideFibonacciEval}; - use crate::constraint_framework::{ - assert_constraints, AssertEvaluator, FrameworkEval, TraceLocationAllocator, - }; - use crate::core::air::Component; - use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; - use crate::core::backend::simd::SimdBackend; - use crate::core::backend::Column; - use crate::core::channel::Blake2sChannel; - #[cfg(not(target_arch = "wasm32"))] - use crate::core::channel::Poseidon252Channel; - #[cfg(not(target_arch = "wasm32"))] - use crate::core::channel::PoseidonBLSChannel; - use crate::core::fields::m31::BaseField; - use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec}; - use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; - use crate::core::poly::BitReversedOrder; - use crate::core::prover::{prove, verify}; - use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; - #[cfg(not(target_arch = "wasm32"))] - use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel; - #[cfg(not(target_arch = "wasm32"))] - use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleChannel; - use crate::core::ColumnVec; - use crate::examples::wide_fibonacci::{generate_trace, FibInput, WideFibonacciComponent}; - - const FIB_SEQUENCE_LENGTH: usize = 100; - - fn generate_test_trace( - log_n_instances: u32, - ) -> ColumnVec> { - let inputs = (0..(1 << (log_n_instances - LOG_N_LANES))) - .map(|i| FibInput { - a: PackedBaseField::one(), - b: PackedBaseField::from_array(std::array::from_fn(|j| { - BaseField::from_u32_unchecked((i * 16 + j) as u32) - })), - }) - .collect_vec(); - generate_trace::(log_n_instances, &inputs) - } - - fn fibonacci_constraint_evaluator(eval: AssertEvaluator<'_>) { - WideFibonacciEval:: { log_n_rows: N }.evaluate(eval); - } - - #[test] - fn test_wide_fibonacci_constraints() { - const LOG_N_INSTANCES: u32 = 6; - let traces = TreeVec::new(vec![generate_test_trace(LOG_N_INSTANCES)]); - let trace_polys = - traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); - - assert_constraints( - &trace_polys, - CanonicCoset::new(LOG_N_INSTANCES), - fibonacci_constraint_evaluator::, - ); - } - - #[test] - #[should_panic] - fn test_wide_fibonacci_constraints_fails() { - const LOG_N_INSTANCES: u32 = 6; - - let mut trace = generate_test_trace(LOG_N_INSTANCES); - // Modify the trace such that a constraint fail. - trace[17].values.set(2, BaseField::one()); - let traces = TreeVec::new(vec![trace]); - let trace_polys = - traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); - - assert_constraints( - &trace_polys, - CanonicCoset::new(LOG_N_INSTANCES), - fibonacci_constraint_evaluator::, - ); - } - - #[test_log::test] - fn test_wide_fib_prove() { - const LOG_N_INSTANCES: u32 = 6; - let config = PcsConfig::default(); - // Precompute twiddles. - let twiddles = SimdBackend::precompute_twiddles( - CanonicCoset::new(LOG_N_INSTANCES + 1 + config.fri_config.log_blowup_factor) - .circle_domain() - .half_coset, - ); - - // Setup protocol. - let prover_channel = &mut Blake2sChannel::default(); - let commitment_scheme = - &mut CommitmentSchemeProver::::new( - config, &twiddles, - ); - - // Trace. - let trace = generate_test_trace(LOG_N_INSTANCES); - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(trace); - tree_builder.commit(prover_channel); - - // Prove constraints. - let component = WideFibonacciComponent::new( - &mut TraceLocationAllocator::default(), - WideFibonacciEval:: { - log_n_rows: LOG_N_INSTANCES, - }, - ); - - let proof = prove::( - &[&component], - prover_channel, - commitment_scheme, - ) - .unwrap(); - - // Verify. - let verifier_channel = &mut Blake2sChannel::default(); - let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); - - // Retrieve the expected column sizes in each commitment interaction, from the AIR. - let sizes = component.trace_log_degree_bounds(); - commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel); - verify(&[&component], verifier_channel, commitment_scheme, proof).unwrap(); - } - - #[test] - #[cfg(not(target_arch = "wasm32"))] - fn test_wide_fib_prove_with_poseidon() { - const LOG_N_INSTANCES: u32 = 6; - - let config = PcsConfig::default(); - // Precompute twiddles. - let twiddles = SimdBackend::precompute_twiddles( - CanonicCoset::new(LOG_N_INSTANCES + 1 + config.fri_config.log_blowup_factor) - .circle_domain() - .half_coset, - ); - - // Setup protocol. - let prover_channel = &mut Poseidon252Channel::default(); - let commitment_scheme = - &mut CommitmentSchemeProver::::new( - config, &twiddles, - ); - - // Trace. - let trace = generate_test_trace(LOG_N_INSTANCES); - //Initialize the parameters of the trace - let mut tree_builder = commitment_scheme.tree_builder(); - // Interpolation of the columns - tree_builder.extend_evals(trace); - // Compute the evaluations of the polynomials, build a Merkle tree of the evaluation and - // update the channel (Fiat-Shamir) with the root of the tree (Transcript <-- root) - tree_builder.commit(prover_channel); - - // Prove constraints. - let component = WideFibonacciComponent::new( - &mut TraceLocationAllocator::default(), - WideFibonacciEval:: { - log_n_rows: LOG_N_INSTANCES, - }, - ); - let proof = prove::( - &[&component], - prover_channel, - commitment_scheme, - ) - .unwrap(); - - // Verify. - let verifier_channel = &mut Poseidon252Channel::default(); - let commitment_scheme = - &mut CommitmentSchemeVerifier::::new(config); - - // Retrieve the expected column sizes in each commitment interaction, from the AIR. - let sizes = component.trace_log_degree_bounds(); - commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel); - verify(&[&component], verifier_channel, commitment_scheme, proof).unwrap(); - } - - #[test] - #[cfg(not(target_arch = "wasm32"))] - fn test_wide_fib_prove_with_poseidon_bls() { - - const LOG_N_INSTANCES: u32 = 6; - - let config = PcsConfig::default(); - // Precompute twiddles. - let twiddles = SimdBackend::precompute_twiddles( - CanonicCoset::new(LOG_N_INSTANCES + 1 + config.fri_config.log_blowup_factor) - .circle_domain() - .half_coset, - ); - - // Setup protocol. - let prover_channel = &mut PoseidonBLSChannel::default(); - let commitment_scheme = - &mut CommitmentSchemeProver::::new( - config, &twiddles, - ); - - // Trace. - let trace = generate_test_trace(LOG_N_INSTANCES); - //Initialize the parameters of the trace - let mut tree_builder = commitment_scheme.tree_builder(); - // Interpolation of the columns - tree_builder.extend_evals(trace); - // Compute the evaluations of the polynomials, build a Merkle tree of the evaluation and - // update the channel (Fiat-Shamir) with the root of the tree (Transcript <-- root) - tree_builder.commit(prover_channel); - - // Prove constraints. - let component = WideFibonacciComponent::new( - &mut TraceLocationAllocator::default(), - WideFibonacciEval:: { - log_n_rows: LOG_N_INSTANCES, - }, - ); - let proof = prove::( - &[&component], - prover_channel, - commitment_scheme, - ) - .unwrap(); - _ = pretty_save_poseidon_bls_proof(&proof); - - // Verify. - let verifier_channel = &mut PoseidonBLSChannel::default(); - let commitment_scheme = - &mut CommitmentSchemeVerifier::::new(config); - - // Retrieve the expected column sizes in each commitment interaction, from the AIR. - let sizes = component.trace_log_degree_bounds(); - commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel); - verify(&[&component], verifier_channel, commitment_scheme, proof).unwrap(); - } -} - diff --git a/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs b/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs deleted file mode 100644 index 53ae956..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs +++ /dev/null @@ -1,186 +0,0 @@ -use std::iter::zip; -use std::ops::{AddAssign, Mul}; - -use educe::Educe; -use num_traits::One; - -use crate::core::backend::simd::SimdBackend; -use crate::core::backend::Backend; -use crate::core::circle::M31_CIRCLE_LOG_ORDER; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::lookups::mle::Mle; -use crate::core::utils::generate_secure_powers; - -pub const MIN_LOG_BLOWUP_FACTOR: u32 = 1; - -/// Max number of variables for multilinear polynomials that get compiled into a univariate -/// IOP for multilinear eval at point. -pub const MAX_MLE_N_VARIABLES: u32 = M31_CIRCLE_LOG_ORDER - MIN_LOG_BLOWUP_FACTOR; - -/// Accumulates [`Mle`]s grouped by their number of variables. -pub struct MleCollection { - mles_by_n_variables: Vec>>>, -} - -impl MleCollection { - /// Appends an [`Mle`] to the collection. - pub fn push(&mut self, mle: impl Into>) { - let mle = mle.into(); - let mles = self.mles_by_n_variables[mle.n_variables()].get_or_insert(Vec::new()); - mles.push(mle); - } -} - -impl MleCollection { - /// Performs a random linear combination of all MLEs, grouped by their number of variables. - /// - /// MLEs are returned in ascending order by number of variables. - pub fn random_linear_combine_by_n_variables( - self, - alpha: SecureField, - ) -> Vec> { - self.mles_by_n_variables - .into_iter() - .flatten() - .map(|mles| mle_random_linear_combination(mles, alpha)) - .collect() - } -} - -/// # Panics -/// -/// Panics if `mles` is empty or all MLEs don't have the same number of variables. -fn mle_random_linear_combination( - mles: Vec>, - alpha: SecureField, -) -> Mle { - assert!(!mles.is_empty()); - let n_variables = mles[0].n_variables(); - assert!(mles.iter().all(|mle| mle.n_variables() == n_variables)); - let alpha_powers = generate_secure_powers(alpha, mles.len()).into_iter().rev(); - let mut mle_and_coeff = zip(mles, alpha_powers); - - // The last value can initialize the accumulator. - let (mle, coeff) = mle_and_coeff.next_back().unwrap(); - assert!(coeff.is_one()); - let mut acc_mle = mle.into_secure_mle(); - - for (mle, coeff) in mle_and_coeff { - match mle { - DynMle::Base(mle) => combine(&mut acc_mle.data, &mle.data, coeff.into()), - DynMle::Secure(mle) => combine(&mut acc_mle.data, &mle.data, coeff.into()), - } - } - - acc_mle -} - -/// Computes all `acc[i] += alpha * v[i]`. -pub fn combine + Copy, F: Copy>( - acc: &mut [EF], - v: &[F], - alpha: EF, -) { - assert_eq!(acc.len(), v.len()); - zip(acc, v).for_each(|(acc, &v)| *acc += alpha * v); -} - -impl Default for MleCollection { - fn default() -> Self { - Self { - mles_by_n_variables: vec![None; MAX_MLE_N_VARIABLES as usize + 1], - } - } -} - -/// Dynamic dispatch for [`Mle`] types. -#[derive(Educe)] -#[educe(Debug, Clone)] -pub enum DynMle { - Base(Mle), - Secure(Mle), -} - -impl DynMle { - fn n_variables(&self) -> usize { - match self { - DynMle::Base(mle) => mle.n_variables(), - DynMle::Secure(mle) => mle.n_variables(), - } - } -} - -impl From> for DynMle { - fn from(mle: Mle) -> Self { - DynMle::Secure(mle) - } -} - -impl From> for DynMle { - fn from(mle: Mle) -> Self { - DynMle::Base(mle) - } -} - -impl DynMle { - fn into_secure_mle(self) -> Mle { - match self { - Self::Base(mle) => Mle::new(mle.into_evals().into_secure_column()), - Self::Secure(mle) => mle, - } - } -} - -#[cfg(test)] -mod tests { - use std::iter::repeat; - - use num_traits::Zero; - - use crate::core::backend::simd::SimdBackend; - use crate::core::fields::m31::BaseField; - use crate::core::fields::qm31::SecureField; - use crate::core::fields::Field; - use crate::core::lookups::mle::{Mle, MleOps}; - use crate::examples::xor::gkr_lookups::accumulation::MleCollection; - - #[test] - fn random_linear_combine_by_n_variables() { - const SMALL_N_VARS: usize = 4; - const LARGE_N_VARS: usize = 6; - let alpha = SecureField::from(10); - let mut mle_collection = MleCollection::::default(); - mle_collection.push(const_mle(SMALL_N_VARS, BaseField::from(1))); - mle_collection.push(const_mle(SMALL_N_VARS, SecureField::from(2))); - mle_collection.push(const_mle(LARGE_N_VARS, BaseField::from(3))); - mle_collection.push(const_mle(LARGE_N_VARS, SecureField::from(4))); - mle_collection.push(const_mle(LARGE_N_VARS, SecureField::from(5))); - let small_eval_point = [SecureField::zero(); SMALL_N_VARS]; - let large_eval_point = [SecureField::zero(); LARGE_N_VARS]; - - let [small_mle, large_mle] = mle_collection - .random_linear_combine_by_n_variables(alpha) - .try_into() - .unwrap(); - - assert_eq!(small_mle.n_variables(), SMALL_N_VARS); - assert_eq!(large_mle.n_variables(), LARGE_N_VARS); - assert_eq!( - small_mle.eval_at_point(&small_eval_point), - SecureField::from(1) * alpha + SecureField::from(2) - ); - assert_eq!( - large_mle.eval_at_point(&large_eval_point), - (SecureField::from(3) * alpha + SecureField::from(4)) * alpha + SecureField::from(5) - ); - } - - fn const_mle(n_variables: usize, v: F) -> Mle - where - B: MleOps, - F: Field, - { - Mle::new(repeat(v).take(1 << n_variables).collect()) - } -} diff --git a/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs deleted file mode 100644 index 5a5d605..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ /dev/null @@ -1,571 +0,0 @@ -//! Multilinear extension (MLE) eval at point constraints. -// TODO(andrew): Remove in downstream PR. -#![allow(dead_code)] - -use std::array; - -use itertools::Itertools; -use num_traits::{One, Zero}; - -use crate::constraint_framework::EvalAtRow; -use crate::core::backend::simd::SimdBackend; -use crate::core::circle::{CirclePoint, Coset}; -use crate::core::constraints::{coset_vanishing, point_vanishing}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SecureColumnByCoords; -use crate::core::fields::{Field, FieldExpOps}; -use crate::core::lookups::utils::eq; -use crate::core::poly::circle::{CanonicCoset, SecureEvaluation}; -use crate::core::poly::BitReversedOrder; -use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; - -/// Evaluates constraints that guarantee an MLE evaluates to a claim at a given point. -/// -/// `mle_coeffs_col_eval` should be the evaluation of the column containing the coefficients of the -/// MLE in the multilinear Lagrange basis. `mle_claim_shift` should equal `claim / 2^N_VARIABLES`. -pub fn eval_mle_eval_constraints( - mle_interaction: usize, - const_interaction: usize, - eval: &mut E, - mle_coeffs_col_eval: E::EF, - mle_eval_point: MleEvalPoint, - mle_claim_shift: SecureField, - carry_quotients_col_eval: E::EF, -) { - let eq_col_eval = eval_eq_constraints( - mle_interaction, - const_interaction, - eval, - mle_eval_point, - carry_quotients_col_eval, - ); - let terms_col_eval = mle_coeffs_col_eval * eq_col_eval; - eval_prefix_sum_constraints(mle_interaction, eval, terms_col_eval, mle_claim_shift) -} - -#[derive(Debug, Clone, Copy)] -pub struct MleEvalPoint { - // Equals `eq({0}^|p|, p)`. - eq_0_p: SecureField, - // Equals `eq({1}^|p|, p)`. - eq_1_p: SecureField, - // Index `i` stores `eq(({1}^|i|, 0), p[0..i+1]) / eq(({0}^|i|, 1), p[0..i+1])`. - eq_carry_quotients: [SecureField; N_VARIABLES], - // Point `p`. - p: [SecureField; N_VARIABLES], -} - -impl MleEvalPoint { - /// Creates new metadata from point `p`. - pub fn new(p: [SecureField; N_VARIABLES]) -> Self { - let zero = SecureField::zero(); - let one = SecureField::one(); - - Self { - eq_0_p: eq(&[zero; N_VARIABLES], &p), - eq_1_p: eq(&[one; N_VARIABLES], &p), - eq_carry_quotients: array::from_fn(|i| { - let mut numer_assignment = vec![one; i + 1]; - numer_assignment[i] = zero; - let mut denom_assignment = vec![zero; i + 1]; - denom_assignment[i] = one; - eq(&numer_assignment, &p[..i + 1]) / eq(&denom_assignment, &p[..i + 1]) - }), - p, - } - } -} - -/// Evaluates EqEvals constraints on a column. -/// -/// Returns the evaluation at offset 0 on the column. -/// -/// Given a column `c(P)` defined on a circle domain `D`, and an MLE eval point `(r0, r1, ...)` -/// evaluates constraints that guarantee: `c(D[b0, b1, ...]) = eq((b0, b1, ...), (r0, r1, ...))`. -/// -/// See (Section 5.1). -fn eval_eq_constraints( - eq_interaction: usize, - const_interaction: usize, - eval: &mut E, - mle_eval_point: MleEvalPoint, - carry_quotients_col_eval: E::EF, -) -> E::EF { - let [curr, next_next] = eval.next_extension_interaction_mask(eq_interaction, [0, 2]); - let [is_first, is_second] = eval.next_interaction_mask(const_interaction, [0, -1]); - - // Check the initial value on half_coset0 and final value on half_coset1. - // Combining these constraints is safe because `is_first` and `is_second` are never - // non-zero at the same time on the trace. - let half_coset0_initial_check = (curr - mle_eval_point.eq_0_p) * is_first; - let half_coset1_final_check = (curr - mle_eval_point.eq_1_p) * is_second; - eval.add_constraint(half_coset0_initial_check + half_coset1_final_check); - - // Check all the steps. - eval.add_constraint(curr - next_next * carry_quotients_col_eval); - - curr -} - -/// Evaluates inclusive prefix sum constraints on a column. -/// -/// Note the column values must be shifted by `cumulative_sum_shift` so the last value equals zero. -/// `cumulative_sum_shift` should equal `cumulative_sum / column_size`. -fn eval_prefix_sum_constraints( - interaction: usize, - eval: &mut E, - row_diff: E::EF, - cumulative_sum_shift: SecureField, -) { - let [curr, prev] = eval.next_extension_interaction_mask(interaction, [0, -1]); - eval.add_constraint(curr - prev - row_diff + cumulative_sum_shift); -} - -/// Returns succinct Eq carry quotients column. -/// -/// Given column `c(P)` defined on a [`CircleDomain`] `D = +-C`, and an MLE eval point -/// `(r0, r1, ...)` let `c(D[b0, b1, ...]) = eq((b0, b1, ...), (r0, r1, ...))`. This function -/// returns column `q(P)` such that all `c(C[i]) = c(C[i + 1]) * q(C[i])` and -/// `c(-C[i]) = c(-C[i + 1]) * q(-C[i])`. -/// -/// [`CircleDomain`]: crate::core::poly::circle::CircleDomain -fn gen_carry_quotient_col( - eval_point: &MleEvalPoint, -) -> SecureEvaluation { - let (half_coset0_carry_quotients, half_coset1_carry_quotients) = - gen_half_coset_carry_quotients(eval_point); - - let log_size = N_VARIABLES as u32; - let size = 1 << log_size; - let half_coset_size = size / 2; - let mut col = SecureColumnByCoords::::zeros(size); - - // TODO(andrew): Optimize. - for i in 0..half_coset_size { - let half_coset0_index = coset_index_to_circle_domain_index(i * 2, log_size); - let half_coset1_index = coset_index_to_circle_domain_index(i * 2 + 1, log_size); - let half_coset0_index_bit_rev = bit_reverse_index(half_coset0_index, log_size); - let half_coset1_index_bit_rev = bit_reverse_index(half_coset1_index, log_size); - - let n_trailing_ones = i.trailing_ones() as usize; - let half_coset0_carry_quotient = half_coset0_carry_quotients[n_trailing_ones]; - let half_coset1_carry_quotient = half_coset1_carry_quotients[n_trailing_ones]; - - col.set(half_coset0_index_bit_rev, half_coset0_carry_quotient); - col.set(half_coset1_index_bit_rev, half_coset1_carry_quotient); - } - - let domain = CanonicCoset::new(log_size).circle_domain(); - SecureEvaluation::new(domain, col) -} - -/// Evaluates the succinct Eq carry quotients column at point `p`. -/// -/// See [`gen_carry_quotient_col`]. -// TODO(andrew): Optimize further. Inline `eval_step_selector` and get runtime down to -// O(N_VARIABLES) vs current O(N_VARIABLES^2). Can also use vanishing evals to compute -// half_coset0_last half_coset1_first. -fn eval_carry_quotient_col( - eval_point: &MleEvalPoint, - p: CirclePoint, -) -> SecureField { - let log_size = N_VARIABLES as u32; - let coset = CanonicCoset::new(log_size).coset(); - - let (half_coset0_carry_quotients, half_coset1_carry_quotients) = - gen_half_coset_carry_quotients(eval_point); - - let mut eval = SecureField::zero(); - - for variable_i in 0..N_VARIABLES.saturating_sub(1) { - let log_step = variable_i as u32 + 2; - let offset = (1 << (log_step - 1)) - 2; - let half_coset0_selector = eval_step_selector_with_offset(coset, offset, log_step, p); - let half_coset1_selector = eval_step_selector_with_offset(coset, offset + 1, log_step, p); - let half_coset0_carry_quotient = half_coset0_carry_quotients[variable_i]; - let half_coset1_carry_quotient = half_coset1_carry_quotients[variable_i]; - eval += half_coset0_selector * half_coset0_carry_quotient; - eval += half_coset1_selector * half_coset1_carry_quotient; - } - - let half_coset0_last = eval_is_first(coset, p + coset.step.double().into_ef()); - let half_coset1_first = eval_is_first(coset, p + coset.step.into_ef()); - eval += *half_coset0_carry_quotients.last().unwrap() * half_coset0_last; - eval += *half_coset1_carry_quotients.last().unwrap() * half_coset1_first; - - eval -} - -/// Evaluates a polynomial that's `1` every `2^log_step` coset points, shifted by an offset, and `0` -/// elsewhere on coset. -fn eval_step_selector_with_offset( - coset: Coset, - offset: usize, - log_step: u32, - p: CirclePoint, -) -> SecureField { - let offset_step = coset.step.mul(offset as u128); - eval_step_selector(coset, log_step, p - offset_step.into_ef()) -} - -/// Evaluates a polynomial that's `1` every `2^log_step` coset points and `0` elsewhere on coset. -fn eval_step_selector(coset: Coset, log_step: u32, p: CirclePoint) -> SecureField { - if log_step == 0 { - return SecureField::one(); - } - - // Rotate the coset to have points on the `x` axis. - let p = p - coset.initial.into_ef(); - let mut vanish_at_log_step = (0..coset.log_size) - .scan(p, |p, _| { - let res = *p; - *p = p.double(); - Some(res.y) - }) - .collect_vec(); - vanish_at_log_step.reverse(); - // We only need the first `log_step` many values. - vanish_at_log_step.truncate(log_step as usize); - let mut vanish_at_log_step_inv = vec![SecureField::zero(); vanish_at_log_step.len()]; - SecureField::batch_inverse(&vanish_at_log_step, &mut vanish_at_log_step_inv); - - let half_coset_selector_dbl = (vanish_at_log_step[0] * vanish_at_log_step_inv[1]).square(); - let vanish_substep_inv_sum = vanish_at_log_step_inv[1..].iter().sum::(); - (half_coset_selector_dbl + vanish_at_log_step[0] * vanish_substep_inv_sum.double()) - / BaseField::from(1 << (log_step + 1)) -} - -fn eval_is_first(coset: Coset, p: CirclePoint) -> SecureField { - coset_vanishing(coset, p) - / (point_vanishing(coset.initial, p) * BaseField::from(1 << coset.log_size)) -} - -/// Output of the form: `(half_coset0_carry_quotients, half_coset1_carry_quotients)`. -fn gen_half_coset_carry_quotients( - eval_point: &MleEvalPoint, -) -> ([SecureField; N_VARIABLES], [SecureField; N_VARIABLES]) { - let last_variable = *eval_point.p.last().unwrap(); - let mut half_coset0_carry_quotients = eval_point.eq_carry_quotients; - *half_coset0_carry_quotients.last_mut().unwrap() *= - eq(&[SecureField::one()], &[last_variable]) / eq(&[SecureField::zero()], &[last_variable]); - let half_coset1_carry_quotients = half_coset0_carry_quotients.map(|v| v.inverse()); - (half_coset0_carry_quotients, half_coset1_carry_quotients) -} - -#[cfg(test)] -mod tests { - use std::array; - use std::iter::{repeat, zip}; - - use itertools::{chain, zip_eq, Itertools}; - use num_traits::One; - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use super::{ - eval_carry_quotient_col, eval_eq_constraints, eval_mle_eval_constraints, - eval_prefix_sum_constraints, gen_carry_quotient_col, MleEvalPoint, - }; - use crate::constraint_framework::constant_columns::{gen_is_first, gen_is_step_with_offset}; - use crate::constraint_framework::{assert_constraints, EvalAtRow}; - use crate::core::backend::simd::column::SecureColumn; - use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; - use crate::core::backend::simd::qm31::PackedSecureField; - use crate::core::backend::simd::SimdBackend; - use crate::core::backend::{Col, Column}; - use crate::core::circle::SECURE_FIELD_CIRCLE_GEN; - use crate::core::fields::m31::BaseField; - use crate::core::fields::qm31::SecureField; - use crate::core::fields::secure_column::SecureColumnByCoords; - use crate::core::lookups::gkr_prover::GkrOps; - use crate::core::lookups::mle::Mle; - use crate::core::pcs::TreeVec; - use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; - use crate::core::poly::BitReversedOrder; - use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; - use crate::examples::xor::gkr_lookups::mle_eval::eval_step_selector_with_offset; - - #[test] - fn test_mle_eval_constraints_with_log_size_5() { - const N_VARIABLES: usize = 5; - const EVAL_TRACE: usize = 0; - const CARRY_QUOTIENTS_TRACE: usize = 1; - const CONST_TRACE: usize = 2; - let mut rng = SmallRng::seed_from_u64(0); - let log_size = N_VARIABLES as u32; - let size = 1 << log_size; - let mle = Mle::new((0..size).map(|_| rng.gen::()).collect()); - let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); - let base_trace = gen_base_trace(&mle, &eval_point); - let claim = mle.eval_at_point(&eval_point); - let claim_shift = claim / BaseField::from(size); - let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point) - .into_coordinate_evals() - .to_vec(); - let constants_trace = gen_constants_trace::(); - let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]); - let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); - let trace_domain = CanonicCoset::new(log_size); - - assert_constraints(&trace_polys, trace_domain, |mut eval| { - let [mle_coeff_col_eval] = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); - let [carry_quotients_col_eval] = - eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]); - eval_mle_eval_constraints( - EVAL_TRACE, - CONST_TRACE, - &mut eval, - mle_coeff_col_eval, - mle_eval_point, - claim_shift, - carry_quotients_col_eval, - ) - }); - } - - #[test] - #[ignore = "SimdBackend `MIN_FFT_LOG_SIZE` is 5"] - fn eq_constraints_with_4_variables() { - const N_VARIABLES: usize = 4; - const EVAL_TRACE: usize = 0; - const CARRY_QUOTIENTS_TRACE: usize = 1; - const CONST_TRACE: usize = 2; - let mut rng = SmallRng::seed_from_u64(0); - let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); - let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); - let base_trace = gen_base_trace(&mle, &eval_point); - let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point) - .into_coordinate_evals() - .to_vec(); - let constants_trace = gen_constants_trace::(); - let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]); - let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); - let trace_domain = CanonicCoset::new(eval_point.len() as u32); - - assert_constraints(&trace_polys, trace_domain, |mut eval| { - let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); - let [carry_quotients_col_eval] = - eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]); - eval_eq_constraints( - EVAL_TRACE, - CONST_TRACE, - &mut eval, - mle_eval_point, - carry_quotients_col_eval, - ); - }); - } - - #[test] - fn eq_constraints_with_5_variables() { - const N_VARIABLES: usize = 5; - const EVAL_TRACE: usize = 0; - const CARRY_QUOTIENTS_TRACE: usize = 1; - const CONST_TRACE: usize = 2; - let mut rng = SmallRng::seed_from_u64(0); - let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); - let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); - let base_trace = gen_base_trace(&mle, &eval_point); - let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point) - .into_coordinate_evals() - .to_vec(); - let constants_trace = gen_constants_trace::(); - let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]); - let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); - let trace_domain = CanonicCoset::new(eval_point.len() as u32); - - assert_constraints(&trace_polys, trace_domain, |mut eval| { - let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); - let [carry_quotients_col_eval] = - eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]); - eval_eq_constraints( - EVAL_TRACE, - CONST_TRACE, - &mut eval, - mle_eval_point, - carry_quotients_col_eval, - ); - }); - } - - #[test] - fn eq_constraints_with_8_variables() { - const N_VARIABLES: usize = 8; - const EVAL_TRACE: usize = 0; - const CARRY_QUOTIENTS_TRACE: usize = 1; - const CONST_TRACE: usize = 2; - let mut rng = SmallRng::seed_from_u64(0); - let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); - let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); - let base_trace = gen_base_trace(&mle, &eval_point); - let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point) - .into_coordinate_evals() - .to_vec(); - let constants_trace = gen_constants_trace::(); - let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]); - let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); - let trace_domain = CanonicCoset::new(eval_point.len() as u32); - - assert_constraints(&trace_polys, trace_domain, |mut eval| { - let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); - let [carry_quotients_col_eval] = - eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]); - eval_eq_constraints( - EVAL_TRACE, - CONST_TRACE, - &mut eval, - mle_eval_point, - carry_quotients_col_eval, - ); - }); - } - - #[test] - fn inclusive_prefix_sum_constraints_with_log_size_5() { - const LOG_SIZE: u32 = 5; - let mut rng = SmallRng::seed_from_u64(0); - let vals = (0..1 << LOG_SIZE).map(|_| rng.gen()).collect_vec(); - let cumulative_sum = vals.iter().sum::(); - let cumulative_sum_shift = cumulative_sum / BaseField::from(vals.len()); - let trace = TreeVec::new(vec![gen_prefix_sum_trace(vals)]); - let trace_polys = trace.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); - let trace_domain = CanonicCoset::new(LOG_SIZE); - - assert_constraints(&trace_polys, trace_domain, |mut eval| { - let [row_diff] = eval.next_extension_interaction_mask(0, [0]); - eval_prefix_sum_constraints(0, &mut eval, row_diff, cumulative_sum_shift) - }); - } - - #[test] - fn eval_step_selector_with_offset_works() { - const LOG_SIZE: u32 = 5; - const OFFSET: usize = 1; - const LOG_STEP: u32 = 2; - let coset = CanonicCoset::new(LOG_SIZE).coset(); - let col_eval = gen_is_step_with_offset::(LOG_SIZE, LOG_STEP, OFFSET); - let col_poly = col_eval.interpolate(); - let p = SECURE_FIELD_CIRCLE_GEN; - - let eval = eval_step_selector_with_offset(coset, OFFSET, LOG_STEP, p); - - assert_eq!(eval, col_poly.eval_at_point(p)); - } - - #[test] - fn eval_carry_quotient_col_works() { - const N_VARIABLES: usize = 5; - let mut rng = SmallRng::seed_from_u64(0); - let mle_eval_point = MleEvalPoint::::new(array::from_fn(|_| rng.gen())); - let col_eval = gen_carry_quotient_col(&mle_eval_point); - let twiddles = SimdBackend::precompute_twiddles(col_eval.domain.half_coset); - let col_poly = col_eval.interpolate_with_twiddles(&twiddles); - let p = SECURE_FIELD_CIRCLE_GEN; - - let eval = eval_carry_quotient_col(&mle_eval_point, p); - - assert_eq!(eval, col_poly.eval_at_point(p)); - } - - /// Generates a trace. - /// - /// Trace structure: - /// - /// ```text - /// ------------------------------------------------------------------------------------- - /// | MLE coeffs | EqEvals (basis) | MLE terms (prefix sum) | - /// ------------------------------------------------------------------------------------- - /// | c0 | c1 | c2 | c3 | c4 | c5 | c6 | c7 | c9 | c9 | c10 | c11 | - /// ------------------------------------------------------------------------------------- - /// ``` - fn gen_base_trace( - mle: &Mle, - eval_point: &[SecureField], - ) -> Vec> { - let mle_coeffs = mle.clone().into_evals(); - let eq_evals = SimdBackend::gen_eq_evals(eval_point, SecureField::one()).into_evals(); - let mle_terms = hadamard_product(&mle_coeffs, &eq_evals); - - let mle_coeff_cols = mle_coeffs.into_secure_column_by_coords().columns; - let eq_evals_cols = eq_evals.into_secure_column_by_coords().columns; - let mle_terms_cols = mle_terms.into_secure_column_by_coords().columns; - - let claim = mle.eval_at_point(eval_point); - let shift = claim / BaseField::from(mle.len()); - let packed_shifts = PackedSecureField::broadcast(shift).into_packed_m31s(); - let mut shifted_mle_terms_cols = mle_terms_cols.clone(); - zip(&mut shifted_mle_terms_cols, packed_shifts) - .for_each(|(col, shift)| col.data.iter_mut().for_each(|v| *v -= shift)); - let shifted_prefix_sum_cols = shifted_mle_terms_cols.map(inclusive_prefix_sum); - - let log_trace_domain_size = mle.n_variables() as u32; - let trace_domain = CanonicCoset::new(log_trace_domain_size).circle_domain(); - - chain![mle_coeff_cols, eq_evals_cols, shifted_prefix_sum_cols] - .map(|c| CircleEvaluation::new(trace_domain, c)) - .collect() - } - - /// Generates a trace. - /// - /// Trace structure: - /// - /// ```text - /// --------------------------------------------------------- - /// | Values | Values prefix sum | - /// --------------------------------------------------------- - /// | c0 | c1 | c2 | c3 | c4 | c5 | c6 | c7 | - /// --------------------------------------------------------- - /// ``` - fn gen_prefix_sum_trace( - values: Vec, - ) -> Vec> { - assert!(values.len().is_power_of_two()); - - let vals_circle_domain_order = coset_order_to_circle_domain_order(&values); - let mut vals_bit_rev_circle_domain_order = vals_circle_domain_order; - bit_reverse(&mut vals_bit_rev_circle_domain_order); - let vals_secure_col: SecureColumnByCoords = - vals_bit_rev_circle_domain_order.into_iter().collect(); - let vals_cols = vals_secure_col.columns; - - let cumulative_sum = values.iter().sum::(); - let cumulative_sum_shift = cumulative_sum / BaseField::from(values.len()); - let packed_cumulative_sum_shift = PackedSecureField::broadcast(cumulative_sum_shift); - let packed_shifts = packed_cumulative_sum_shift.into_packed_m31s(); - let mut shifted_cols = vals_cols.clone(); - zip(&mut shifted_cols, packed_shifts) - .for_each(|(col, packed_shift)| col.data.iter_mut().for_each(|v| *v -= packed_shift)); - let shifted_prefix_sum_cols = shifted_cols.map(inclusive_prefix_sum); - - let log_size = values.len().ilog2(); - let trace_domain = CanonicCoset::new(log_size).circle_domain(); - - chain![vals_cols, shifted_prefix_sum_cols] - .map(|c| CircleEvaluation::new(trace_domain, c)) - .collect() - } - - /// Returns the element-wise product of `a` and `b`. - fn hadamard_product( - a: &Col, - b: &Col, - ) -> Col { - assert_eq!(a.len(), b.len()); - SecureColumn { - data: zip_eq(&a.data, &b.data).map(|(&a, &b)| a * b).collect(), - length: a.len(), - } - } - - fn gen_constants_trace( - ) -> Vec> { - let log_size = N_VARIABLES as u32; - vec![gen_is_first(log_size)] - } -} diff --git a/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/mod.rs b/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/mod.rs deleted file mode 100644 index 6ee603e..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod accumulation; -pub mod mle_eval; diff --git a/Stwo_wrapper/crates/prover/src/examples/xor/mod.rs b/Stwo_wrapper/crates/prover/src/examples/xor/mod.rs deleted file mode 100644 index 34e702a..0000000 --- a/Stwo_wrapper/crates/prover/src/examples/xor/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod gkr_lookups; diff --git a/Stwo_wrapper/crates/prover/src/lib.rs b/Stwo_wrapper/crates/prover/src/lib.rs deleted file mode 100644 index 1e9c3be..0000000 --- a/Stwo_wrapper/crates/prover/src/lib.rs +++ /dev/null @@ -1,23 +0,0 @@ -#![allow(incomplete_features)] -#![feature( - array_chunks, - array_methods, - array_try_from_fn, - assert_matches, - exact_size_is_empty, - generic_const_exprs, - get_many_mut, - int_roundings, - is_sorted, - iter_array_chunks, - new_uninit, - portable_simd, - slice_first_last_chunk, - slice_flatten, - slice_group_by, - stdsimd -)] -pub mod constraint_framework; -pub mod core; -pub mod examples; -pub mod math; diff --git a/Stwo_wrapper/crates/prover/src/math/matrix.rs b/Stwo_wrapper/crates/prover/src/math/matrix.rs deleted file mode 100644 index 697f862..0000000 --- a/Stwo_wrapper/crates/prover/src/math/matrix.rs +++ /dev/null @@ -1,67 +0,0 @@ -use crate::core::fields::m31::BaseField; -use crate::core::fields::ExtensionOf; - -pub trait SquareMatrix, const N: usize> { - fn get_at(&self, i: usize, j: usize) -> F; - fn mul(&self, v: [F; N]) -> [F; N] { - (0..N) - .map(|i| { - (0..N) - .map(|j| self.get_at(i, j) * v[j]) - .fold(F::zero(), |acc, x| acc + x) - }) - .collect::>() - .try_into() - .unwrap() - } -} - -pub struct RowMajorMatrix, const N: usize> { - values: [[F; N]; N], -} - -impl, const N: usize> RowMajorMatrix { - pub fn new(values: Vec) -> Self { - assert_eq!(values.len(), N * N); - Self { - values: values - .chunks(N) - .map(|chunk| chunk.try_into().unwrap()) - .collect::>() - .try_into() - .unwrap(), - } - } -} - -impl, const N: usize> SquareMatrix for RowMajorMatrix { - fn get_at(&self, i: usize, j: usize) -> F { - self.values[i][j] - } -} - -#[cfg(test)] -mod tests { - use crate::core::fields::m31::M31; - use crate::m31; - use crate::math::matrix::{RowMajorMatrix, SquareMatrix}; - - #[test] - fn test_matrix_multiplication() { - let matrix = RowMajorMatrix::::new((0..9).map(|x| m31!(x + 1)).collect::>()); - let vector = (0..3) - .map(|x| m31!(x + 1)) - .collect::>() - .try_into() - .unwrap(); - let expected_result = [ - m31!(14), // 1 * 1 + 2 * 2 + 3 * 3 - m31!(32), // 4 * 1 + 5 * 2 + 6 * 3 - m31!(50), // 7 * 1 + 8 * 2 + 9 * 3 - ]; - - let result = matrix.mul(vector); - - assert_eq!(result, expected_result); - } -} diff --git a/Stwo_wrapper/crates/prover/src/math/mod.rs b/Stwo_wrapper/crates/prover/src/math/mod.rs deleted file mode 100644 index 42c0c38..0000000 --- a/Stwo_wrapper/crates/prover/src/math/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod matrix; -pub mod utils; diff --git a/Stwo_wrapper/crates/prover/src/math/utils.rs b/Stwo_wrapper/crates/prover/src/math/utils.rs deleted file mode 100644 index fd177fa..0000000 --- a/Stwo_wrapper/crates/prover/src/math/utils.rs +++ /dev/null @@ -1,24 +0,0 @@ -/// Returns s, t, g such that g = gcd(x,y), sx + ty = g. -pub fn egcd(x: isize, y: isize) -> (isize, isize, isize) { - if x == 0 { - return (0, 1, y); - } - let k = y / x; - let (s, t, g) = egcd(y % x, x); - (t - s * k, s, g) -} - -#[cfg(test)] -mod tests { - use crate::math::utils::egcd; - - #[test] - fn test_egcd() { - let pairs = [(17, 5, 1), (6, 4, 2), (7, 7, 7)]; - for (x, y, res) in pairs.into_iter() { - let (a, b, gcd) = egcd(x, y); - assert_eq!(gcd, res); - assert_eq!(a * x + b * y, gcd); - } - } -} diff --git a/Stwo_wrapper/poseidon_benchmark.sh b/Stwo_wrapper/poseidon_benchmark.sh deleted file mode 100755 index c796985..0000000 --- a/Stwo_wrapper/poseidon_benchmark.sh +++ /dev/null @@ -1,3 +0,0 @@ -LOG_N_INSTANCES=18 RUST_LOG_SPAN_EVENTS=enter,close RUST_LOG=info \ - RUSTFLAGS="-C target-cpu=native -C opt-level=3" \ - cargo test test_simd_poseidon_prove -- --nocapture diff --git a/Stwo_wrapper/resources/img/logo.png b/Stwo_wrapper/resources/img/logo.png deleted file mode 100644 index 07d6055..0000000 Binary files a/Stwo_wrapper/resources/img/logo.png and /dev/null differ diff --git a/Stwo_wrapper/rust-toolchain.toml b/Stwo_wrapper/rust-toolchain.toml deleted file mode 100644 index a0f1a93..0000000 --- a/Stwo_wrapper/rust-toolchain.toml +++ /dev/null @@ -1,2 +0,0 @@ -[toolchain] -channel = "nightly-2024-01-04" diff --git a/Stwo_wrapper/rustfmt.toml b/Stwo_wrapper/rustfmt.toml deleted file mode 100644 index 1214011..0000000 --- a/Stwo_wrapper/rustfmt.toml +++ /dev/null @@ -1,12 +0,0 @@ -# See: https://rust-lang.github.io/rustfmt -normalize_comments = true -use_field_init_shorthand = true - -# Unstable -comment_width = 100 -condense_wildcard_suffixes = true -format_code_in_doc_comments = true -group_imports = "StdExternalCrate" -imports_granularity = "Module" -unstable_features = true -wrap_comments = true diff --git a/Stwo_wrapper/scripts/bench.sh b/Stwo_wrapper/scripts/bench.sh deleted file mode 100755 index 416cfb3..0000000 --- a/Stwo_wrapper/scripts/bench.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash -# Can be used as a drop in replacement for `cargo bench`. -# For example, `./scripts/bench.sh` will run all benchmarks. -# or `./scripts/bench.sh M31` will run only the M31 benchmarks. -RUSTFLAGS="-Awarnings -C target-cpu=native -C target-feature=+avx512f -C opt-level=3" cargo bench $@ diff --git a/Stwo_wrapper/scripts/clippy.sh b/Stwo_wrapper/scripts/clippy.sh deleted file mode 100755 index 8361cd2..0000000 --- a/Stwo_wrapper/scripts/clippy.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash -cargo +nightly-2024-01-04 clippy "$@" --all-targets --all-features -- -D warnings -D future-incompatible \ - -D nonstandard-style -D rust-2018-idioms -D unused diff --git a/Stwo_wrapper/scripts/rust_fmt.sh b/Stwo_wrapper/scripts/rust_fmt.sh deleted file mode 100755 index e4223f9..0000000 --- a/Stwo_wrapper/scripts/rust_fmt.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - -cargo +nightly-2024-01-04 fmt --all -- "$@" diff --git a/Stwo_wrapper/scripts/test_avx.sh b/Stwo_wrapper/scripts/test_avx.sh deleted file mode 100755 index d911a24..0000000 --- a/Stwo_wrapper/scripts/test_avx.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash -# Can be used as a drop in replacement for `cargo test` with avx512f flag on. -# For example, `./scripts/test_avx.sh` will run all tests(not only avx). -RUSTFLAGS="-Awarnings -C target-cpu=native -C target-feature=+avx512f -C opt-level=2" cargo +nightly-2024-01-04 test "$@" diff --git a/Stwo_wrapper/verifier_script.ipynb b/Stwo_wrapper/verifier_script.ipynb deleted file mode 100644 index 433afa9..0000000 --- a/Stwo_wrapper/verifier_script.ipynb +++ /dev/null @@ -1,581 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "43f7fdc0", - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "\n", - "with open(\"crates/prover/proof.json\", \"r\") as f:\n", - " proof = json.load(f)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "88f232b5", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "12ef3767", - "metadata": {}, - "outputs": [], - "source": [ - "F = FiniteField(52435875175126190479447740508185965837690552500527637822603658699938581184513)\n", - "\n", - "poseidon_comp_consts = [50207570499218320245539736680169582180207201335688461025883902752909290481781,\n", - "24448666467656506447555018649749346340705294023832615387641453784702583464707,\n", - "34092944507611308604157957266676007619644244199372265837364557849561670729974,\n", - "46954129210702959446093971191783182601726081775951103310666314834569091037713,\n", - "38612156878839717097806285947575477749087608521505464809942918879152074545066,\n", - "19752610610343814834081989345964253902282700341539483876504601969121084774539,\n", - "46567545048462867923299713424766325689670511126407629551256255807498976196546,\n", - "9520793415506326549109545537894287560752519598132096386048093015534488804808,\n", - "22814234098357034097599682726494820560934925862581927123816510593532324971186,\n", - "3277621627834606517208177071759088097855048183641615082769528872043050020787,\n", - "29230456498980145088774069819561206654397510279226264474986155631775387918911,\n", - "19087113294497892618475669593723876605785307026981218038380435259594863105240,\n", - "39932371919358015185769877859035474336011770016475087638554815294278664040916,\n", - "17645770319151120318035258350885823104235488352935695302274836429012504407725,\n", - "17990728141399065004015538797609951295983853332644474801890158217822768128628,\n", - "12607949331462269429981198199999740921418125994747028428126661151190418292729,\n", - "33067617079394435172767143524489677593390850035349407507374659268468278200906,\n", - "10025233623562179533044093426455032352895184661359005809314430689113735312874,\n", - "20398677688057466110325934731430812468657996794663167456321709689030080949228,\n", - "32085671199853825909918260218834827339732598508827083525700252644622592932757,\n", - "36451986593067827349794003109666944974266236856145879921902940325507228739480,\n", - "51835224419566813714481533481210630888564327175625175437244377303858990291964,\n", - "1944662263588038198375346521900053780907777056656211622999059135594196413076,\n", - "12995068374816903282074967132431954020410301768622808407703775963080983755183,\n", - "13278128079226679628648689279705910775020794457648431336050464485837924986341,\n", - "39207195481789228835625472428521288347432218258431761869689775532020546099642,\n", - "21081768833381902942114733002158882075348844281359283013642620389621494952015,\n", - "20751788049060260683191405008569080723662271828149227137187075968560831545739,\n", - "20820291785607398388900832350860967875629907105847554413318238165275470374689,\n", - "6971878585215744613467847324629115462668098071102846520957717612260531709386,\n", - "42421164250058173810994728364144776180689735894673627964404703973460802099146,\n", - "32890116643831560295329417521056875595733120141391587236744387068135440602102,\n", - "42670005614507618780436482775021159957307712089310941922452133588875084445464,\n", - "21120353743307986506720883740380468652053382764895882204680310593048134053982,\n", - "7853308243263055176258751393326645428041138029306706980470113526802326214700,\n", - "17545076036297840030021082424260289805456380863517895917265467158332801090765,\n", - "29526223376722400691172584788126610514669516909826971155598997488361793726636,\n", - "48421712782536172546302502401679048379568171245541707202282458591545347755349,\n", - "10740853637774754893036062076749871837371049036966225040269105665447180116170,\n", - "34042041521558704677804677569712674569738576001717295340556848855085089618161,\n", - "24290796201833228559129233924595614281891670608675107544294264860003803501509,\n", - "26722678647461522072509896114724736555938247563993442152746954157222882824350,\n", - "20252491387019425681551488261397157776479297799360691728406809731508542196845,\n", - "50322025264206689090790987370440439179141270613911973034521438238687587958097,\n", - "17070806525931584028449131949070191143344166668070820337429561524629464200550,\n", - "25856554324149146992239414502939942208580094928192925471532421030223074525051,\n", - "17714998974036855356530338446243137421735047395517260588250413348153258772076,\n", - "44833315250334176776685835079382312848180252180173884969157994737319426976437,\n", - "35603718839327251012037553292043899153393807438387129505923567878785822738162,\n", - "20515196301761603016197694845695272699608637099106794944737311528118558777570,\n", - "10100400556460905874275078234698187530913105549037797180493988678937053918124,\n", - "29943022708270799252522211109308629054849337552699067311814388215768905671554,\n", - "33400164627534996188947689774080657908147988421361870074239537729877153299092,\n", - "45574161704098228712016716221086232277248798839906622903502141601878895917316,\n", - "40623265267364613450776577487319920007897396936924051398790906883872334022964,\n", - "37929176440858430683261948300797278761072096845318183419284347376614069989808,\n", - "12242010394227909997626655999345208835040087302065045201635069094289920778463,\n", - "38947272924417356803622776795797899233194116520680026665045628837194239730633,\n", - "6838505804652359252670794375725267665530548946030641535297433541475260948424,\n", - "21345718918993308853491352363460625447157796362108157527364130872100101143328,\n", - "26397988737034501095129796920971941795766209722106383463197090306632188634870,\n", - "47092791129593573928369881528796435131623991381197863072979392492232678100884,\n", - "36850972241154890671857874025605504779963735054128436776319531005864791472123,\n", - "27893799443241349360688137159923920340185830261519093384488134540544971987330,\n", - "34031071010517479317003393843135868322188010660871691856659878788331169912272,\n", - "3102550735908358465878301372253437950829524988677083749179431098369388780259,\n", - "2963742902601529003553690631564645593518709846059084207036841793643477514707,\n", - "34538583661636382515652368664945657625216404085453317149263146639486246251503,\n", - "49179786922858759927440465310900376749726765337268308911471491527044937447403,\n", - "31668552784983283483593666924944066737680315058069542500069213700768949573692,\n", - "47303630019147536941220901582952982856517915740884282232588733470564849742080,\n", - "41561182787858915334837446901194440640033856888621022207410120224293681204923,\n", - "40208795410444394963490428737133513683110766973508056822474493355065333491217,\n", - "24620569969402072776192280888011017497854992833864712509770555543278833718751,\n", - "31418811028946653724823259636547682581071379929451162101915628592655152015310,\n", - "25964807298150242099204032696543021731332498792173212422070959505270506288817,\n", - "31766013031271106581980804902159064978010553325475976472264348555438361464655,\n", - "15107529391758643095716794813038523751713309080738989300826699946985294497278,\n", - "26149402682269665088314773514719203730233986608723938665192802061570851149320,\n", - "35053126320072620250684851851709987160095640397875384355477447570643983599564]\n", - "\n", - "def mix(state):\n", - " state[0] = state[0] + state[1] + state[2];\n", - " state[1] = state[0] + state[1];\n", - " state[2] = state[0] + state[2];\n", - " return state\n", - "\n", - "def round_comp(state, idx, full):\n", - " if full:\n", - " state[0] += poseidon_comp_consts[idx]\n", - " state[1] += poseidon_comp_consts[idx + 1]\n", - " state[2] += poseidon_comp_consts[idx + 2]\n", - " # Optimize multiplication\n", - " state[0] = state[0] * state[0] * state[0] * state[0] * state[0]\n", - " state[1] = state[1] * state[1] * state[1] * state[1] * state[1]\n", - " state[2] = state[2] * state[2] * state[2] * state[2] * state[2]\n", - " else:\n", - " state[0] += poseidon_comp_consts[idx]\n", - " state[2] = state[2] * state[2] * state[2] * state[2] * state[2]\n", - " state = mix(state)\n", - " return state\n", - "\n", - "def poseidon_permute_comp_bls(state):\n", - " idx = 0;\n", - " state = mix(state);\n", - "\n", - " # Full rounds\n", - " for i in range(4):\n", - " state = round_comp(state, idx, true);\n", - " idx += 3;\n", - "\n", - " # Partial rounds\n", - " for i in range(56):\n", - " state = round_comp(state, idx, false);\n", - " idx += 1;\n", - "\n", - " # Full rounds\n", - " for i in range(4):\n", - " state= round_comp(state, idx, true);\n", - " idx += 3;\n", - " return state\n", - "\n", - " \n", - "def poseidon_hash_bls(x, y):\n", - " state = [F(x), F(y), F(0)];\n", - " poseidon_permute_comp_bls(state);\n", - " return state[0] + F(x)\n", - "\n", - "\n", - "def poseidon_hash_many_bls(msgs):\n", - " state = [F(0), F(0), F(0)];\n", - " for i in range(len(msgs)//2):\n", - " state[0] += F(msgs[2*i])\n", - " state[1] += F(msgs[2*i + 1])\n", - " state = poseidon_permute_comp_bls(state)\n", - " if len(msgs) % 2 == 1:\n", - " state[0] += F(msgs[len(msgs)-1])\n", - " state[len(msgs)%2] += F(1);\n", - " state = poseidon_permute_comp_bls(state)\n", - " return state[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "67a09953", - "metadata": {}, - "outputs": [], - "source": [ - "def draw_felt252(digest, n_sent):\n", - " res = poseidon_hash_bls(digest, F(n_sent));\n", - " return res\n", - "\n", - "def draw_base_felts(digest, n_sent):\n", - " shift = 1 << 31\n", - " cur = int(draw_felt252(digest, n_sent))\n", - " n_sent += 1\n", - " u32s = []\n", - " quotient = 1\n", - " while(quotient != 0):\n", - " quotient = cur // shift\n", - " if quotient != 0:\n", - " remainder = cur % shift\n", - " cur = quotient\n", - " u32s.append(M31(remainder))\n", - " return u32s\n", - "\n", - "def draw_random_bytes(digest, n_sent):\n", - " shift = 1 << 8\n", - " cur = int(draw_felt252(digest, n_sent))\n", - " byte = []\n", - " for i in range(31):\n", - " quotient = cur // shift\n", - " remainder = cur % shift\n", - " cur = quotient\n", - " byte.append(remainder)\n", - " return byte\n", - "\n", - "def hash_node(has_children, left, right, column_values):\n", - " n_column_blocks = ceil(len(column_values) / 8.0);\n", - " values = []\n", - " if has_children:\n", - " values.append(left)\n", - " values.append(right)\n", - " padding_length = 8 * n_column_blocks - len(column_values)\n", - " padded_values = column_values + [F(0) for i in range(padding_length)]\n", - " for i in range(int(len(padded_values) / 8)):\n", - " word = F(0)\n", - " for j in range(8):\n", - " word = word * F(2**31) + F(padded_values[i*8+j])\n", - " values.append(word)\n", - " return poseidon_hash_many_bls(values)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "1bce81d8", - "metadata": {}, - "outputs": [], - "source": [ - "M31 = FiniteField(2**31 - 1)\n", - "Poly_1. = PolynomialRing(M31)\n", - "poly_1 = x**2+1\n", - "CM31. = M31.extension(poly_1)\n", - "CM31\n", - "\n", - "Poly_2. = PolynomialRing(CM31)\n", - "poly_2 = y**2 - 2 - i\n", - "QM31. = CM31.extension(poly_2)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "16aa579c", - "metadata": {}, - "outputs": [], - "source": [ - "# A class to represent points on a circle using fields.\n", - "class CirclePoint:\n", - " def __init__(self, x, y):\n", - " self.x = x\n", - " self.y = y\n", - "\n", - " @staticmethod\n", - " def zero():\n", - " \"\"\"Returns the identity element of the circle.\"\"\"\n", - " return CirclePoint(1, 0)\n", - "\n", - " def double(self):\n", - " \"\"\"Returns the point after doubling.\"\"\"\n", - " return self + self\n", - "\n", - " @staticmethod\n", - " def double_x(x):\n", - " \"\"\"Applies the x-coordinate doubling map.\"\"\"\n", - " sx = x**2\n", - " return sx + sx - 1\n", - "\n", - " def log_order(self):\n", - " \"\"\"Returns the log order of the point (as a power of 2).\"\"\"\n", - " res = 0\n", - " cur = self.x\n", - " while cur != 1:\n", - " cur = self.double_x(cur)\n", - " res += 1\n", - " return res\n", - "\n", - " def mul(self, scalar):\n", - " \"\"\"Multiplies the point by a scalar.\"\"\"\n", - " res = CirclePoint.zero()\n", - " cur = self\n", - " while scalar > 0:\n", - " if scalar & 1 == 1:\n", - " res = res + cur\n", - " cur = cur + cur\n", - " scalar = int(scalar / 2)\n", - " return res\n", - "\n", - " def repeated_double(self, n):\n", - " \"\"\"Returns the point after repeated doubling.\"\"\"\n", - " res = self\n", - " for _ in range(n):\n", - " res = res.double()\n", - " return res\n", - "\n", - " def conjugate(self):\n", - " \"\"\"Returns the conjugate of the point.\"\"\"\n", - " return CirclePoint(self.x, -self.y)\n", - "\n", - " def antipode(self):\n", - " \"\"\"Returns the antipode of the point.\"\"\"\n", - " return CirclePoint(-self.x, -self.y)\n", - "\n", - " def mul_signed(self, off):\n", - " \"\"\"Multiplies the point by a signed scalar.\"\"\"\n", - " if off > 0:\n", - " return self.mul(off)\n", - " else:\n", - " return self.conjugate().mul(-off)\n", - "\n", - " def __add__(self, other):\n", - " \"\"\"Adds two circle points.\"\"\"\n", - " return CirclePoint(self.x * other.x - self.y * other.y, self.x * other.y + self.y * other.x)\n", - "\n", - " def __repr__(self):\n", - " return f\"CirclePoint(x={self.x}, y={self.y})\"" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "d9371b02", - "metadata": {}, - "outputs": [], - "source": [ - "# First Fiat-Shamir\n", - "digest = poseidon_hash_bls(0,proof[\"commitments\"][0])\n", - "\n", - "# Draw first random coeff\n", - "random_coeff = QM31([draw_base_felts(digest,0)[:2],draw_base_felts(digest,0)[2:4]])\n", - "\n", - "# Second Fiat-Shamir\n", - "digest = poseidon_hash_bls(digest,proof[\"commitments\"][1])\n", - "\n", - "# Draw OODS sample point\n", - "t = QM31([draw_base_felts(digest,0)[:2],draw_base_felts(digest,0)[2:4]])\n", - "t_square = t**2\n", - "one_plus_tsquared_inv = (t_square + QM31(1)) ** (-1)\n", - "x = (1 - t_square) * one_plus_tsquared_inv\n", - "y = 2*t*one_plus_tsquared_inv\n", - "oods_point = CirclePoint(x,y)\n", - "\n", - "# Derive sample points for now offset is always 0 so sampled points are all equals oods_point\n", - "\n", - "# Load sampled_values_1 and verify that they are in QM31\n", - "\n", - "# Then we verify the composition polynomial evaluation:\n", - "point = oods_point\n", - "mask_values = proof[\"sampled_values_0\"] + proof[\"sampled_values_1\"]\n", - "\n", - "\n", - "\n", - "\n", - "# Evaluate random point on the vanishing polynomial of the coset\n", - "evaluation = point.x\n", - "for i in range(5):\n", - " evaluation = CirclePoint.double_x(evaluation)\n", - "evaluation_inverse = evaluation ** (-1)\n", - "\n", - "# Compute evaluation using the mask values\n", - "accumulator = QM31(0)\n", - "a = QM31([mask_values[0][:2],mask_values[0][2:4]])\n", - "b = QM31([mask_values[1][:2],mask_values[1][2:4]])\n", - "for i in range(len(proof[\"sampled_values_0\"])-2):\n", - " c = QM31([mask_values[2+i][:2],mask_values[2+i][2:4]])\n", - " accumulator = (c - (a**2 + b**2))*evaluation_inverse + accumulator * random_coeff\n", - " a = b\n", - " b = c\n", - "\n", - "expected_value = QM31([proof[\"sampled_values_1\"][0][:2],proof[\"sampled_values_1\"][0][2:4]]) + QM31([proof[\"sampled_values_1\"][1][:2],proof[\"sampled_values_1\"][1][2:4]]) * QM31([[0,1],[0,0]]) + QM31([proof[\"sampled_values_1\"][2][:2],proof[\"sampled_values_1\"][2][2:4]]) * QM31([[0,0],[1,0]]) + QM31([proof[\"sampled_values_1\"][3][:2],proof[\"sampled_values_1\"][3][2:4]]) * QM31([[0,0],[0,1]])\n", - "\n", - "assert(expected_value == accumulator)\n", - "\n", - "\n", - "####FRI part\n", - "# Fiat-Shamir\n", - "shift = F(1 << 31)\n", - "res = [digest]\n", - "for i in range(len(mask_values) / 2):\n", - " cur = F(0)\n", - " for j in range(2):\n", - " for k in range(4):\n", - " cur = cur * shift + F(mask_values[2*i+j][k])\n", - " res.append(cur)\n", - "digest = poseidon_hash_many_bls(res)\n", - "random_coeff = QM31([draw_base_felts(digest,0)[:2],draw_base_felts(digest,0)[2:4]])\n", - "\n", - "# Verify commitment stage\n", - "circle_poly_alpha = QM31([draw_base_felts(digest,1)[:2],draw_base_felts(digest,1)[2:4]])\n", - "folding_alpha = []\n", - "for i in range(6):\n", - " digest = poseidon_hash_bls(digest,proof[\"inner_commitment_\"+str(i)])\n", - " folding_alpha.append(QM31([draw_base_felts(digest,0)[:2],draw_base_felts(digest,0)[2:4]]))\n", - "last_layer_poly = QM31([proof[\"coeffs\"][:2],proof[\"coeffs\"][2:4]])\n", - "\n", - "#Check proof of work\n", - "res = [digest]\n", - "cur = F(0)\n", - "for i in range(4):\n", - " cur = cur * shift + F(proof[\"coeffs\"][i])\n", - "res.append(cur)\n", - "digest = poseidon_hash_many_bls(res)\n", - "digest = poseidon_hash_bls(digest, proof[\"proof of work\"])\n", - "# The first 5 bits following the first 1 must be 0 for proof of work\n", - "for i in range(5):\n", - " assert(bin(digest)[3+i] == '0')\n", - "\n", - "# Compute openings positions\n", - "querries = []\n", - "max_querry = (1<<8)-1\n", - "random_bytes = draw_random_bytes(digest,0)\n", - "# The number of required querries is 3 so we need 3*8 bytes wich is enough with the 31 bytes of random_bytes\n", - "for i in range(3):\n", - " querry_bits = int().from_bytes(random_bytes[i*4:i*4+4],byteorder=\"little\")\n", - " querries.append(querry_bits & max_querry)\n", - "positions = []\n", - "# For each collumn\n", - "for i in range(2):\n", - " # For each querry\n", - " for j in range(3):\n", - " positions.append(querries[j] >> (1+i))\n", - "positions_sav = positions.copy()\n", - "\n", - "\n", - "#Verify decommitments (this part must be simplified repeting duplicated inputs)\n", - "# First tree\n", - "nodes = {}\n", - "for i in range(6):\n", - " leaf = []\n", - " for j in range(100):\n", - " leaf.append(M31(proof[\"queried_values_0\"][j][i]))\n", - " nodes[\"level_0_node_\"+str(positions[i//2] // 2 * 2 + i % 2) ] = hash_node(false,0,0,leaf)\n", - "counter = 0\n", - "for level in range(7):\n", - " for i in range(3):\n", - " if \"level_\"+str(level)+\"_node_\"+str(positions[len(positions)-3] * 2) in nodes:\n", - " lhs = nodes[\"level_\"+str(level)+\"_node_\"+str(positions[len(positions)-3] * 2)]\n", - " else:\n", - " lhs = F(proof[\"decommitment_0\"][counter])\n", - " counter += 1\n", - " if \"level_\"+str(level)+\"_node_\"+str(positions[len(positions) - 3] * 2 + 1) in nodes:\n", - " rhs = nodes[\"level_\"+str(level)+\"_node_\"+str(positions[len(positions) - 3] * 2 + 1)]\n", - " else:\n", - " rhs = F(proof[\"decommitment_0\"][counter])\n", - " counter += 1\n", - " nodes[\"level_\"+str(level+1)+\"_node_\"+str(positions[len(positions)-3])] = hash_node(true,lhs,rhs,[])\n", - " positions.append(positions[len(positions)-3] // 2)\n", - "assert(nodes[\"level_7_node_0\"] == F(proof[\"commitments\"][0]))\n", - "\n", - "#Second tree\n", - "positions = positions_sav\n", - "for i in range(len(positions)):\n", - " positions[i] = positions[i] * 2\n", - "nodes = {}\n", - "for i in range(6):\n", - " leaf = []\n", - " for j in range(4):\n", - " leaf.append(M31(proof[\"queried_values_1\"][j][i]))\n", - " nodes[\"level_0_node_\"+str(positions[i//2] // 2 * 2 + i % 2) ] = hash_node(false,0,0,leaf)\n", - "counter = 0\n", - "for level in range(8):\n", - " for i in range(3):\n", - " if \"level_\"+str(level)+\"_node_\"+str(positions[len(positions)-3] * 2) in nodes:\n", - " lhs = nodes[\"level_\"+str(level)+\"_node_\"+str(positions[len(positions)-3] * 2)]\n", - " else:\n", - " lhs = F(proof[\"decommitment_1\"][counter])\n", - " counter += 1\n", - " if \"level_\"+str(level)+\"_node_\"+str(positions[len(positions) - 3] * 2 + 1) in nodes:\n", - " rhs = nodes[\"level_\"+str(level)+\"_node_\"+str(positions[len(positions) - 3] * 2 + 1)]\n", - " else:\n", - " rhs = F(proof[\"decommitment_1\"][counter])\n", - " counter += 1\n", - " nodes[\"level_\"+str(level+1)+\"_node_\"+str(positions[len(positions)-3])] = hash_node(true,lhs,rhs,[])\n", - " if level == 0:\n", - " nodes[\"level_1_node_\"+str(positions[len(positions)-3] + 1)] = hash_node(true,nodes[\"level_0_node_\"+str((positions[len(positions)-3] + 1)* 2)],nodes[\"level_0_node_\"+str((positions[len(positions)-3] + 1)* 2 + 1)],[])\n", - " positions.append(positions[len(positions)-3] // 2)\n", - "assert(nodes[\"level_8_node_0\"] == F(proof[\"commitments\"][1]))\n", - "\n", - "\n", - "# Check querries answer we have 104 samples each equal to x = oods_point and f(x) = sampled_value_i\n", - "# For each column:\n", - "for i in range(2):\n", - " evaluations = []\n", - " queried_values_per_column = proof[\"queried_values_\"+str(i)]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "56e4c016", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6b89599c", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "e486cb3a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[['316772341',\n", - " '1526280133',\n", - " '663010112',\n", - " '224983897',\n", - " '510598760',\n", - " '1109503351'],\n", - " ['754832207',\n", - " '435790299',\n", - " '883623752',\n", - " '553207508',\n", - " '154784232',\n", - " '199176676'],\n", - " ['689603315',\n", - " '1763523007',\n", - " '1720552945',\n", - " '1983603154',\n", - " '367841669',\n", - " '319325418'],\n", - " ['1290247052',\n", - " '1120744584',\n", - " '193500372',\n", - " '294491115',\n", - " '951360807',\n", - " '891034447']]" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "queried_values_per_column" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "85dd52b3", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "SageMath 9.5", - "language": "sage", - "name": "sagemath" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}