mirror of
https://github.com/logos-blockchain/logos-blockchain-pocs.git
synced 2026-01-05 14:43:08 +00:00
Merge pull request #56 from logos-co/origin/Circom_PoL
Merge the two circom branches
This commit is contained in:
commit
6299aaa843
22
Stwo_wrapper/Cargo.toml
Normal file
22
Stwo_wrapper/Cargo.toml
Normal file
@ -0,0 +1,22 @@
|
||||
[workspace]
|
||||
members = ["crates/prover"]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.1.1"
|
||||
edition = "2021"
|
||||
|
||||
[workspace.dependencies]
|
||||
blake2 = "0.10.6"
|
||||
blake3 = "1.5.0"
|
||||
educe = "0.5.0"
|
||||
hex = "0.4.3"
|
||||
itertools = "0.12.0"
|
||||
num-traits = "0.2.17"
|
||||
thiserror = "1.0.56"
|
||||
bytemuck = "1.14.3"
|
||||
tracing = "0.1.40"
|
||||
|
||||
[profile.bench]
|
||||
codegen-units = 1
|
||||
lto = true
|
||||
201
Stwo_wrapper/LICENSE
Normal file
201
Stwo_wrapper/LICENSE
Normal file
@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2024 StarkWare Industries Ltd.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
62
Stwo_wrapper/README.md
Normal file
62
Stwo_wrapper/README.md
Normal file
@ -0,0 +1,62 @@
|
||||
<div align="center">
|
||||
|
||||

|
||||
|
||||
<a href="https://github.com/starkware-libs/stwo/actions/workflows/ci.yaml"><img alt="GitHub Workflow Status (with event)" src="https://img.shields.io/github/actions/workflow/status/starkware-libs/stwo/ci.yaml?style=for-the-badge" height=30></a>
|
||||
<a href="https://codecov.io/gh/starkware-libs/stwo" >
|
||||
<img src="https://img.shields.io/codecov/c/github/starkware-libs/stwo?style=for-the-badge&logo=codecov" height=30/>
|
||||
</a>
|
||||
<a href="https://github.com/starkware-libs/stwo/blob/main/LICENSE"><img src="https://img.shields.io/github/license/starkware-libs/stwo.svg?style=for-the-badge" alt="Project license" height="30"></a>
|
||||
<a href="https://starkware.co/"><img src="https://img.shields.io/badge/By StarkWare-29296E.svg?&style=for-the-badge&logo=data:image/svg%2bxml;base64,PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0iVVRGLTgiPz48c3ZnIGlkPSJhIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCAxODEgMTgxIj48ZGVmcz48c3R5bGU+LmJ7ZmlsbDojZmZmO308L3N0eWxlPjwvZGVmcz48cGF0aCBjbGFzcz0iYiIgZD0iTTE3Ni43Niw4OC4xOGwtMzYtMzcuNDNjLTEuMzMtMS40OC0zLjQxLTIuMDQtNS4zMS0xLjQybC0xMC42MiwyLjk4LTEyLjk1LDMuNjNoLjc4YzUuMTQtNC41Nyw5LjktOS41NSwxNC4yNS0xNC44OSwxLjY4LTEuNjgsMS44MS0yLjcyLDAtNC4yN0w5Mi40NSwuNzZxLTEuOTQtMS4wNC00LjAxLC4xM2MtMTIuMDQsMTIuNDMtMjMuODMsMjQuNzQtMzYsMzcuNjktMS4yLDEuNDUtMS41LDMuNDQtLjc4LDUuMThsNC4yNywxNi41OGMwLDIuNzIsMS40Miw1LjU3LDIuMDcsOC4yOS00LjczLTUuNjEtOS43NC0xMC45Ny0xNS4wMi0xNi4wNi0xLjY4LTEuODEtMi41OS0xLjgxLTQuNCwwTDQuMzksODguMDVjLTEuNjgsMi4zMy0xLjgxLDIuMzMsMCw0LjUzbDM1Ljg3LDM3LjNjMS4zNiwxLjUzLDMuNSwyLjEsNS40NCwxLjQybDExLjQtMy4xMSwxMi45NS0zLjYzdi45MWMtNS4yOSw0LjE3LTEwLjIyLDguNzYtMTQuNzYsMTMuNzNxLTMuNjMsMi45OC0uNzgsNS4zMWwzMy40MSwzNC44NGMyLjIsMi4yLDIuOTgsMi4yLDUuMTgsMGwzNS40OC0zNy4xN2MxLjU5LTEuMzgsMi4xNi0zLjYsMS40Mi01LjU3LTEuNjgtNi4wOS0zLjI0LTEyLjMtNC43OS0xOC4zOS0uNzQtMi4yNy0xLjIyLTQuNjItMS40Mi02Ljk5LDQuMyw1LjkzLDkuMDcsMTEuNTIsMTQuMjUsMTYuNzEsMS42OCwxLjY4LDIuNzIsMS42OCw0LjQsMGwzNC4zMi0zNS43NHExLjU1LTEuODEsMC00LjAxWm0tNzIuMjYsMTUuMTVjLTMuMTEtLjc4LTYuMDktMS41NS05LjE5LTIuNTktMS43OC0uMzQtMy42MSwuMy00Ljc5LDEuNjhsLTEyLjk1LDEzLjg2Yy0uNzYsLjg1LTEuNDUsMS43Ni0yLjA3LDIuNzJoLS42NWMxLjMtNS4zMSwyLjcyLTEwLjYyLDQuMDEtMTUuOGwxLjY4LTYuNzNjLjg0LTIuMTgsLjE1LTQuNjUtMS42OC02LjA5bC0xMi45NS0xNC4xMmMtLjY0LS40NS0xLjE0LTEuMDgtMS40Mi0xLjgxbDE5LjA0LDUuMTgsMi41OSwuNzhjMi4wNCwuNzYsNC4zMywuMTQsNS43LTEuNTVsMTIuOTUtMTQuMzhzLjc4LTEuMDQsMS42OC0xLjE3Yy0xLjgxLDYuNi0yLjk4LDE0LjEyLTUuNDQsMjAuNDYtMS4wOCwyLjk2LS4wOCw2LjI4LDIuNDYsOC4xNiw0LjI3LDQuMTQsOC4yOSw4LjU1LDEyLjk1LDEyLjk1LDAsMCwxLjMsLjkxLDEuNDIsMi4wN2wtMTMuMzQtMy42M1oiLz48L3N2Zz4=" alt="StarkWare" height="30"></a>
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
<h3>
|
||||
<a href="https://eprint.iacr.org/2024/278">
|
||||
Paper
|
||||
</a>
|
||||
<span> | </span>
|
||||
<a href="https://github.com/starkware-libs/stwo">
|
||||
Documentation
|
||||
</a>
|
||||
<span> | </span>
|
||||
<a href="https://starkware-libs.github.io/stwo/dev/bench/index.html">
|
||||
Benchmarks
|
||||
</a>
|
||||
</h3>
|
||||
</div>
|
||||
|
||||
# Stwo
|
||||
|
||||
## 🌟 About
|
||||
|
||||
Stwo is a next generation implementation of a [CSTARK](https://eprint.iacr.org/2024/278) prover and verifier, written in Rust 🦀.
|
||||
|
||||
> **Stwo is a work in progress.**
|
||||
>
|
||||
> It is not recommended to use it in a production setting yet.
|
||||
|
||||
## 🚀 Key Features
|
||||
|
||||
- **Circle STARKs:** Based on the latest cryptographic research and innovations in the ZK field.
|
||||
- **High performance:** Stwo is designed to be extremely fast and efficient.
|
||||
- **Flexible:** Adaptable for various validity proof applications.
|
||||
|
||||
## 📊 Benchmarks
|
||||
|
||||
Run `poseidon_benchmark.sh` to run a single-threaded poseidon2 hash proof benchmark.
|
||||
|
||||
Further benchmarks can be run using `cargo bench`.
|
||||
|
||||
Visual representation of benchmarks can be found [here](https://starkware-libs.github.io/stwo/dev/bench/index.html).
|
||||
|
||||
## 📜 License
|
||||
|
||||
This project is licensed under the **Apache 2.0 license**.
|
||||
|
||||
See [LICENSE](LICENSE) for more information.
|
||||
|
||||
<!-- markdownlint-restore -->
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
<!-- ALL-CONTRIBUTORS-LIST:END -->
|
||||
0
Stwo_wrapper/WORKSPACE
Normal file
0
Stwo_wrapper/WORKSPACE
Normal file
110
Stwo_wrapper/crates/prover/Cargo.toml
Normal file
110
Stwo_wrapper/crates/prover/Cargo.toml
Normal file
@ -0,0 +1,110 @@
|
||||
[package]
|
||||
name = "stwo-prover"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[features]
|
||||
parallel = ["rayon"]
|
||||
slow-tests = []
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
blake2.workspace = true
|
||||
blake3.workspace = true
|
||||
bytemuck = { workspace = true, features = ["derive", "extern_crate_alloc"] }
|
||||
cfg-if = "1.0.0"
|
||||
downcast-rs = "1.2"
|
||||
educe.workspace = true
|
||||
hex.workspace = true
|
||||
itertools.workspace = true
|
||||
num-traits.workspace = true
|
||||
rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
|
||||
starknet-crypto = "0.6.2"
|
||||
starknet-ff = "0.3.7"
|
||||
ark-bls12-381 = "0.4.0"
|
||||
ark-ff = "0.4.0"
|
||||
thiserror.workspace = true
|
||||
tracing.workspace = true
|
||||
rayon = { version = "1.10.0", optional = true }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
crypto-bigint = "0.5.5"
|
||||
ark-serialize = "0.4.0"
|
||||
serde_json = "1.0.116"
|
||||
|
||||
[dev-dependencies]
|
||||
aligned = "0.4.2"
|
||||
test-log = { version = "0.2.15", features = ["trace"] }
|
||||
tracing-subscriber = "0.3.18"
|
||||
[target.'cfg(all(target_family = "wasm", not(target_os = "wasi")))'.dev-dependencies]
|
||||
wasm-bindgen-test = "0.3.43"
|
||||
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies.criterion]
|
||||
features = ["html_reports"]
|
||||
version = "0.5.1"
|
||||
|
||||
# Default features cause compile error:
|
||||
# "Rayon cannot be used when targeting wasi32. Try disabling default features."
|
||||
[target.'cfg(target_arch = "wasm32")'.dev-dependencies.criterion]
|
||||
default-features = false
|
||||
features = ["html_reports"]
|
||||
version = "0.5.1"
|
||||
|
||||
[lib]
|
||||
bench = false
|
||||
crate-type = ["cdylib", "lib"]
|
||||
|
||||
[lints.rust]
|
||||
warnings = "deny"
|
||||
future-incompatible = "deny"
|
||||
nonstandard-style = "deny"
|
||||
rust-2018-idioms = "deny"
|
||||
unused = "deny"
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "bit_rev"
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "eval_at_point"
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "fft"
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "field"
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "fri"
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "lookups"
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "matrix"
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "merkle"
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "poseidon"
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "prefix_sum"
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "quotients"
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "pcs"
|
||||
2
Stwo_wrapper/crates/prover/benches/README.md
Normal file
2
Stwo_wrapper/crates/prover/benches/README.md
Normal file
@ -0,0 +1,2 @@
|
||||
dev benchmark results can be seen at
|
||||
https://starkware-libs.github.io/stwo/dev/bench/index.html
|
||||
39
Stwo_wrapper/crates/prover/benches/bit_rev.rs
Normal file
39
Stwo_wrapper/crates/prover/benches/bit_rev.rs
Normal file
@ -0,0 +1,39 @@
|
||||
#![feature(iter_array_chunks)]
|
||||
|
||||
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
|
||||
use itertools::Itertools;
|
||||
use stwo_prover::core::fields::m31::BaseField;
|
||||
|
||||
pub fn cpu_bit_rev(c: &mut Criterion) {
|
||||
use stwo_prover::core::utils::bit_reverse;
|
||||
// TODO(andrew): Consider using same size for all.
|
||||
const SIZE: usize = 1 << 24;
|
||||
let data = (0..SIZE).map(BaseField::from).collect_vec();
|
||||
c.bench_function("cpu bit_rev 24bit", |b| {
|
||||
b.iter_batched(
|
||||
|| data.clone(),
|
||||
|mut data| bit_reverse(&mut data),
|
||||
BatchSize::LargeInput,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn simd_bit_rev(c: &mut Criterion) {
|
||||
use stwo_prover::core::backend::simd::bit_reverse::bit_reverse_m31;
|
||||
use stwo_prover::core::backend::simd::column::BaseColumn;
|
||||
const SIZE: usize = 1 << 26;
|
||||
let data = (0..SIZE).map(BaseField::from).collect::<BaseColumn>();
|
||||
c.bench_function("simd bit_rev 26bit", |b| {
|
||||
b.iter_batched(
|
||||
|| data.data.clone(),
|
||||
|mut data| bit_reverse_m31(&mut data),
|
||||
BatchSize::LargeInput,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
name = bit_rev;
|
||||
config = Criterion::default().sample_size(10);
|
||||
targets = simd_bit_rev, cpu_bit_rev);
|
||||
criterion_main!(bit_rev);
|
||||
35
Stwo_wrapper/crates/prover/benches/eval_at_point.rs
Normal file
35
Stwo_wrapper/crates/prover/benches/eval_at_point.rs
Normal file
@ -0,0 +1,35 @@
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use stwo_prover::core::backend::cpu::CpuBackend;
|
||||
use stwo_prover::core::backend::simd::SimdBackend;
|
||||
use stwo_prover::core::circle::CirclePoint;
|
||||
use stwo_prover::core::fields::m31::BaseField;
|
||||
use stwo_prover::core::poly::circle::{CirclePoly, PolyOps};
|
||||
|
||||
const LOG_SIZE: u32 = 20;
|
||||
|
||||
fn bench_eval_at_secure_point<B: PolyOps>(c: &mut Criterion, id: &str) {
|
||||
let poly = CirclePoly::new((0..1 << LOG_SIZE).map(BaseField::from).collect());
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let x = rng.gen();
|
||||
let y = rng.gen();
|
||||
let point = CirclePoint { x, y };
|
||||
c.bench_function(
|
||||
&format!("{id} eval_at_secure_field_point 2^{LOG_SIZE}"),
|
||||
|b| {
|
||||
b.iter(|| B::eval_at_point(black_box(&poly), black_box(point)));
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn eval_at_secure_point_benches(c: &mut Criterion) {
|
||||
bench_eval_at_secure_point::<SimdBackend>(c, "simd");
|
||||
bench_eval_at_secure_point::<CpuBackend>(c, "cpu");
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
name = benches;
|
||||
config = Criterion::default().sample_size(10);
|
||||
targets = eval_at_secure_point_benches);
|
||||
criterion_main!(benches);
|
||||
131
Stwo_wrapper/crates/prover/benches/fft.rs
Normal file
131
Stwo_wrapper/crates/prover/benches/fft.rs
Normal file
@ -0,0 +1,131 @@
|
||||
#![feature(iter_array_chunks)]
|
||||
|
||||
use std::hint::black_box;
|
||||
use std::mem::{size_of_val, transmute};
|
||||
|
||||
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput};
|
||||
use itertools::Itertools;
|
||||
use stwo_prover::core::backend::simd::column::BaseColumn;
|
||||
use stwo_prover::core::backend::simd::fft::ifft::{
|
||||
get_itwiddle_dbls, ifft, ifft3_loop, ifft_vecwise_loop,
|
||||
};
|
||||
use stwo_prover::core::backend::simd::fft::rfft::{fft, get_twiddle_dbls};
|
||||
use stwo_prover::core::backend::simd::fft::transpose_vecs;
|
||||
use stwo_prover::core::backend::simd::m31::PackedBaseField;
|
||||
use stwo_prover::core::fields::m31::BaseField;
|
||||
use stwo_prover::core::poly::circle::CanonicCoset;
|
||||
|
||||
pub fn simd_ifft(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("iffts");
|
||||
|
||||
for log_size in 16..=28 {
|
||||
let domain = CanonicCoset::new(log_size).circle_domain();
|
||||
let twiddle_dbls = get_itwiddle_dbls(domain.half_coset);
|
||||
let twiddle_dbls_refs = twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec();
|
||||
let values: BaseColumn = (0..domain.size()).map(BaseField::from).collect();
|
||||
group.throughput(Throughput::Bytes(size_of_val(&*values.data) as u64));
|
||||
group.bench_function(BenchmarkId::new("simd ifft", log_size), |b| {
|
||||
b.iter_batched(
|
||||
|| values.clone().data,
|
||||
|mut data| unsafe {
|
||||
ifft(
|
||||
transmute(data.as_mut_ptr()),
|
||||
black_box(&twiddle_dbls_refs),
|
||||
black_box(log_size as usize),
|
||||
);
|
||||
},
|
||||
BatchSize::LargeInput,
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn simd_ifft_parts(c: &mut Criterion) {
|
||||
const LOG_SIZE: u32 = 14;
|
||||
|
||||
let domain = CanonicCoset::new(LOG_SIZE).circle_domain();
|
||||
let twiddle_dbls = get_itwiddle_dbls(domain.half_coset);
|
||||
let twiddle_dbls_refs = twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec();
|
||||
let values: BaseColumn = (0..domain.size()).map(BaseField::from).collect();
|
||||
|
||||
let mut group = c.benchmark_group("ifft parts");
|
||||
|
||||
// Note: These benchmarks run only on 2^LOG_SIZE elements because of their parameters.
|
||||
// Increasing the figure above won't change the runtime of these benchmarks.
|
||||
group.throughput(Throughput::Bytes(4 << LOG_SIZE));
|
||||
group.bench_function(format!("simd ifft_vecwise_loop 2^{LOG_SIZE}"), |b| {
|
||||
b.iter_batched(
|
||||
|| values.clone().data,
|
||||
|mut values| unsafe {
|
||||
ifft_vecwise_loop(
|
||||
transmute(values.as_mut_ptr()),
|
||||
black_box(&twiddle_dbls_refs),
|
||||
black_box(9),
|
||||
black_box(0),
|
||||
)
|
||||
},
|
||||
BatchSize::LargeInput,
|
||||
);
|
||||
});
|
||||
group.bench_function(format!("simd ifft3_loop 2^{LOG_SIZE}"), |b| {
|
||||
b.iter_batched(
|
||||
|| values.clone().data,
|
||||
|mut values| unsafe {
|
||||
ifft3_loop(
|
||||
transmute(values.as_mut_ptr()),
|
||||
black_box(&twiddle_dbls_refs[3..]),
|
||||
black_box(7),
|
||||
black_box(4),
|
||||
black_box(0),
|
||||
)
|
||||
},
|
||||
BatchSize::LargeInput,
|
||||
);
|
||||
});
|
||||
|
||||
const TRANSPOSE_LOG_SIZE: u32 = 20;
|
||||
let transpose_values: BaseColumn = (0..1 << TRANSPOSE_LOG_SIZE).map(BaseField::from).collect();
|
||||
group.throughput(Throughput::Bytes(4 << TRANSPOSE_LOG_SIZE));
|
||||
group.bench_function(format!("simd transpose_vecs 2^{TRANSPOSE_LOG_SIZE}"), |b| {
|
||||
b.iter_batched(
|
||||
|| transpose_values.clone().data,
|
||||
|mut values| unsafe {
|
||||
transpose_vecs(
|
||||
transmute(values.as_mut_ptr()),
|
||||
black_box(TRANSPOSE_LOG_SIZE as usize - 4),
|
||||
)
|
||||
},
|
||||
BatchSize::LargeInput,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn simd_rfft(c: &mut Criterion) {
|
||||
const LOG_SIZE: u32 = 20;
|
||||
|
||||
let domain = CanonicCoset::new(LOG_SIZE).circle_domain();
|
||||
let twiddle_dbls = get_twiddle_dbls(domain.half_coset);
|
||||
let twiddle_dbls_refs = twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec();
|
||||
let values: BaseColumn = (0..domain.size()).map(BaseField::from).collect();
|
||||
|
||||
c.bench_function("simd rfft 20bit", |b| {
|
||||
b.iter_with_large_drop(|| unsafe {
|
||||
let mut target = Vec::<PackedBaseField>::with_capacity(values.data.len());
|
||||
#[allow(clippy::uninit_vec)]
|
||||
target.set_len(values.data.len());
|
||||
|
||||
fft(
|
||||
black_box(transmute(values.data.as_ptr())),
|
||||
transmute(target.as_mut_ptr()),
|
||||
black_box(&twiddle_dbls_refs),
|
||||
black_box(LOG_SIZE as usize),
|
||||
)
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
name = benches;
|
||||
config = Criterion::default().sample_size(10);
|
||||
targets = simd_ifft, simd_ifft_parts, simd_rfft);
|
||||
criterion_main!(benches);
|
||||
150
Stwo_wrapper/crates/prover/benches/field.rs
Normal file
150
Stwo_wrapper/crates/prover/benches/field.rs
Normal file
@ -0,0 +1,150 @@
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
use num_traits::One;
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use stwo_prover::core::backend::simd::m31::{PackedBaseField, N_LANES};
|
||||
use stwo_prover::core::fields::cm31::CM31;
|
||||
use stwo_prover::core::fields::m31::{BaseField, M31};
|
||||
use stwo_prover::core::fields::qm31::SecureField;
|
||||
|
||||
pub const N_ELEMENTS: usize = 1 << 16;
|
||||
pub const N_STATE_ELEMENTS: usize = 8;
|
||||
|
||||
pub fn m31_operations_bench(c: &mut Criterion) {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let elements: Vec<M31> = (0..N_ELEMENTS).map(|_| rng.gen()).collect();
|
||||
let mut state: [M31; N_STATE_ELEMENTS] = rng.gen();
|
||||
|
||||
c.bench_function("M31 mul", |b| {
|
||||
b.iter(|| {
|
||||
for elem in &elements {
|
||||
for _ in 0..128 {
|
||||
for state_elem in &mut state {
|
||||
*state_elem *= *elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
c.bench_function("M31 add", |b| {
|
||||
b.iter(|| {
|
||||
for elem in &elements {
|
||||
for _ in 0..128 {
|
||||
for state_elem in &mut state {
|
||||
*state_elem += *elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
pub fn cm31_operations_bench(c: &mut Criterion) {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let elements: Vec<CM31> = (0..N_ELEMENTS).map(|_| rng.gen()).collect();
|
||||
let mut state: [CM31; N_STATE_ELEMENTS] = rng.gen();
|
||||
|
||||
c.bench_function("CM31 mul", |b| {
|
||||
b.iter(|| {
|
||||
for elem in &elements {
|
||||
for _ in 0..128 {
|
||||
for state_elem in &mut state {
|
||||
*state_elem *= *elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
c.bench_function("CM31 add", |b| {
|
||||
b.iter(|| {
|
||||
for elem in &elements {
|
||||
for _ in 0..128 {
|
||||
for state_elem in &mut state {
|
||||
*state_elem += *elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
pub fn qm31_operations_bench(c: &mut Criterion) {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let elements: Vec<SecureField> = (0..N_ELEMENTS).map(|_| rng.gen()).collect();
|
||||
let mut state: [SecureField; N_STATE_ELEMENTS] = rng.gen();
|
||||
|
||||
c.bench_function("SecureField mul", |b| {
|
||||
b.iter(|| {
|
||||
for elem in &elements {
|
||||
for _ in 0..128 {
|
||||
for state_elem in &mut state {
|
||||
*state_elem *= *elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
c.bench_function("SecureField add", |b| {
|
||||
b.iter(|| {
|
||||
for elem in &elements {
|
||||
for _ in 0..128 {
|
||||
for state_elem in &mut state {
|
||||
*state_elem += *elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
pub fn simd_m31_operations_bench(c: &mut Criterion) {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let elements: Vec<PackedBaseField> = (0..N_ELEMENTS / N_LANES).map(|_| rng.gen()).collect();
|
||||
let mut states = vec![PackedBaseField::broadcast(BaseField::one()); N_STATE_ELEMENTS];
|
||||
|
||||
c.bench_function("mul_simd", |b| {
|
||||
b.iter(|| {
|
||||
for elem in elements.iter() {
|
||||
for _ in 0..128 {
|
||||
for state in states.iter_mut() {
|
||||
*state *= *elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
c.bench_function("add_simd", |b| {
|
||||
b.iter(|| {
|
||||
for elem in elements.iter() {
|
||||
for _ in 0..128 {
|
||||
for state in states.iter_mut() {
|
||||
*state += *elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
c.bench_function("sub_simd", |b| {
|
||||
b.iter(|| {
|
||||
for elem in elements.iter() {
|
||||
for _ in 0..128 {
|
||||
for state in states.iter_mut() {
|
||||
*state -= *elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
name = benches;
|
||||
config = Criterion::default().sample_size(10);
|
||||
targets = m31_operations_bench, cm31_operations_bench, qm31_operations_bench,
|
||||
simd_m31_operations_bench);
|
||||
criterion_main!(benches);
|
||||
35
Stwo_wrapper/crates/prover/benches/fri.rs
Normal file
35
Stwo_wrapper/crates/prover/benches/fri.rs
Normal file
@ -0,0 +1,35 @@
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
use stwo_prover::core::backend::CpuBackend;
|
||||
use stwo_prover::core::fields::m31::BaseField;
|
||||
use stwo_prover::core::fields::qm31::SecureField;
|
||||
use stwo_prover::core::fields::secure_column::SecureColumnByCoords;
|
||||
use stwo_prover::core::fri::FriOps;
|
||||
use stwo_prover::core::poly::circle::{CanonicCoset, PolyOps};
|
||||
use stwo_prover::core::poly::line::{LineDomain, LineEvaluation};
|
||||
|
||||
fn folding_benchmark(c: &mut Criterion) {
|
||||
const LOG_SIZE: u32 = 12;
|
||||
let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset());
|
||||
let evals = LineEvaluation::new(
|
||||
domain,
|
||||
SecureColumnByCoords {
|
||||
columns: std::array::from_fn(|i| {
|
||||
vec![BaseField::from_u32_unchecked(i as u32); 1 << LOG_SIZE]
|
||||
}),
|
||||
},
|
||||
);
|
||||
let alpha = SecureField::from_u32_unchecked(2213980, 2213981, 2213982, 2213983);
|
||||
let twiddles = CpuBackend::precompute_twiddles(domain.coset());
|
||||
c.bench_function("fold_line", |b| {
|
||||
b.iter(|| {
|
||||
black_box(CpuBackend::fold_line(
|
||||
black_box(&evals),
|
||||
black_box(alpha),
|
||||
&twiddles,
|
||||
));
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(benches, folding_benchmark);
|
||||
criterion_main!(benches);
|
||||
104
Stwo_wrapper/crates/prover/benches/lookups.rs
Normal file
104
Stwo_wrapper/crates/prover/benches/lookups.rs
Normal file
@ -0,0 +1,104 @@
|
||||
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
|
||||
use rand::distributions::{Distribution, Standard};
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use stwo_prover::core::backend::simd::SimdBackend;
|
||||
use stwo_prover::core::backend::CpuBackend;
|
||||
use stwo_prover::core::channel::Blake2sChannel;
|
||||
use stwo_prover::core::fields::Field;
|
||||
use stwo_prover::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer};
|
||||
use stwo_prover::core::lookups::mle::{Mle, MleOps};
|
||||
|
||||
const LOG_N_ROWS: u32 = 16;
|
||||
|
||||
fn bench_gkr_grand_product<B: GkrOps>(c: &mut Criterion, id: &str) {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let layer = Layer::<B>::GrandProduct(gen_random_mle(&mut rng, LOG_N_ROWS));
|
||||
c.bench_function(&format!("{id} grand product lookup 2^{LOG_N_ROWS}"), |b| {
|
||||
b.iter_batched(
|
||||
|| layer.clone(),
|
||||
|layer| prove_batch(&mut Blake2sChannel::default(), vec![layer]),
|
||||
BatchSize::LargeInput,
|
||||
)
|
||||
});
|
||||
c.bench_function(
|
||||
&format!("{id} grand product lookup batch 4x 2^{LOG_N_ROWS}"),
|
||||
|b| {
|
||||
b.iter_batched(
|
||||
|| vec![layer.clone(), layer.clone(), layer.clone(), layer.clone()],
|
||||
|layers| prove_batch(&mut Blake2sChannel::default(), layers),
|
||||
BatchSize::LargeInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn bench_gkr_logup_generic<B: GkrOps>(c: &mut Criterion, id: &str) {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let generic_layer = Layer::<B>::LogUpGeneric {
|
||||
numerators: gen_random_mle(&mut rng, LOG_N_ROWS),
|
||||
denominators: gen_random_mle(&mut rng, LOG_N_ROWS),
|
||||
};
|
||||
c.bench_function(&format!("{id} generic logup lookup 2^{LOG_N_ROWS}"), |b| {
|
||||
b.iter_batched(
|
||||
|| generic_layer.clone(),
|
||||
|layer| prove_batch(&mut Blake2sChannel::default(), vec![layer]),
|
||||
BatchSize::LargeInput,
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_gkr_logup_multiplicities<B: GkrOps>(c: &mut Criterion, id: &str) {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let multiplicities_layer = Layer::<B>::LogUpMultiplicities {
|
||||
numerators: gen_random_mle(&mut rng, LOG_N_ROWS),
|
||||
denominators: gen_random_mle(&mut rng, LOG_N_ROWS),
|
||||
};
|
||||
c.bench_function(
|
||||
&format!("{id} multiplicities logup lookup 2^{LOG_N_ROWS}"),
|
||||
|b| {
|
||||
b.iter_batched(
|
||||
|| multiplicities_layer.clone(),
|
||||
|layer| prove_batch(&mut Blake2sChannel::default(), vec![layer]),
|
||||
BatchSize::LargeInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn bench_gkr_logup_singles<B: GkrOps>(c: &mut Criterion, id: &str) {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let singles_layer = Layer::<B>::LogUpSingles {
|
||||
denominators: gen_random_mle(&mut rng, LOG_N_ROWS),
|
||||
};
|
||||
c.bench_function(&format!("{id} singles logup lookup 2^{LOG_N_ROWS}"), |b| {
|
||||
b.iter_batched(
|
||||
|| singles_layer.clone(),
|
||||
|layer| prove_batch(&mut Blake2sChannel::default(), vec![layer]),
|
||||
BatchSize::LargeInput,
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
/// Generates a random multilinear polynomial.
|
||||
fn gen_random_mle<B: MleOps<F>, F: Field>(rng: &mut impl Rng, n_variables: u32) -> Mle<B, F>
|
||||
where
|
||||
Standard: Distribution<F>,
|
||||
{
|
||||
Mle::new((0..1 << n_variables).map(|_| rng.gen()).collect())
|
||||
}
|
||||
|
||||
fn gkr_lookup_benches(c: &mut Criterion) {
|
||||
bench_gkr_grand_product::<SimdBackend>(c, "simd");
|
||||
bench_gkr_logup_generic::<SimdBackend>(c, "simd");
|
||||
bench_gkr_logup_multiplicities::<SimdBackend>(c, "simd");
|
||||
bench_gkr_logup_singles::<SimdBackend>(c, "simd");
|
||||
|
||||
bench_gkr_grand_product::<CpuBackend>(c, "cpu");
|
||||
bench_gkr_logup_generic::<CpuBackend>(c, "cpu");
|
||||
bench_gkr_logup_multiplicities::<CpuBackend>(c, "cpu");
|
||||
bench_gkr_logup_singles::<CpuBackend>(c, "cpu");
|
||||
}
|
||||
|
||||
criterion_group!(benches, gkr_lookup_benches);
|
||||
criterion_main!(benches);
|
||||
63
Stwo_wrapper/crates/prover/benches/matrix.rs
Normal file
63
Stwo_wrapper/crates/prover/benches/matrix.rs
Normal file
@ -0,0 +1,63 @@
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use stwo_prover::core::fields::m31::{M31, P};
|
||||
use stwo_prover::core::fields::qm31::QM31;
|
||||
use stwo_prover::math::matrix::{RowMajorMatrix, SquareMatrix};
|
||||
|
||||
const MATRIX_SIZE: usize = 24;
|
||||
const QM31_MATRIX_SIZE: usize = 6;
|
||||
|
||||
// TODO(ShaharS): Share code with other benchmarks.
|
||||
fn row_major_matrix_multiplication_bench(c: &mut Criterion) {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
|
||||
let matrix_m31 = RowMajorMatrix::<M31, MATRIX_SIZE>::new(
|
||||
(0..MATRIX_SIZE.pow(2))
|
||||
.map(|_| rng.gen())
|
||||
.collect::<Vec<M31>>(),
|
||||
);
|
||||
|
||||
let matrix_qm31 = RowMajorMatrix::<QM31, QM31_MATRIX_SIZE>::new(
|
||||
(0..QM31_MATRIX_SIZE.pow(2))
|
||||
.map(|_| rng.gen())
|
||||
.collect::<Vec<QM31>>(),
|
||||
);
|
||||
|
||||
// Create vector M31.
|
||||
let vec: [M31; MATRIX_SIZE] = rng.gen();
|
||||
|
||||
// Create vector QM31.
|
||||
let vec_qm31: [QM31; QM31_MATRIX_SIZE] = [(); QM31_MATRIX_SIZE].map(|_| {
|
||||
QM31::from_u32_unchecked(
|
||||
rng.gen::<u32>() % P,
|
||||
rng.gen::<u32>() % P,
|
||||
rng.gen::<u32>() % P,
|
||||
rng.gen::<u32>() % P,
|
||||
)
|
||||
});
|
||||
|
||||
// bench matrix multiplication.
|
||||
c.bench_function(
|
||||
&format!("RowMajorMatrix M31 {size}x{size} mul", size = MATRIX_SIZE),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
black_box(matrix_m31.mul(vec));
|
||||
})
|
||||
},
|
||||
);
|
||||
c.bench_function(
|
||||
&format!(
|
||||
"QM31 RowMajorMatrix {size}x{size} mul",
|
||||
size = QM31_MATRIX_SIZE
|
||||
),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
black_box(matrix_qm31.mul(vec_qm31));
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
criterion_group!(benches, row_major_matrix_multiplication_bench);
|
||||
criterion_main!(benches);
|
||||
38
Stwo_wrapper/crates/prover/benches/merkle.rs
Normal file
38
Stwo_wrapper/crates/prover/benches/merkle.rs
Normal file
@ -0,0 +1,38 @@
|
||||
#![feature(iter_array_chunks)]
|
||||
|
||||
use criterion::{criterion_group, criterion_main, Criterion, Throughput};
|
||||
use itertools::Itertools;
|
||||
use num_traits::Zero;
|
||||
use stwo_prover::core::backend::simd::SimdBackend;
|
||||
use stwo_prover::core::backend::{Col, CpuBackend};
|
||||
use stwo_prover::core::fields::m31::{BaseField, N_BYTES_FELT};
|
||||
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleHasher;
|
||||
use stwo_prover::core::vcs::ops::MerkleOps;
|
||||
|
||||
const LOG_N_ROWS: u32 = 16;
|
||||
|
||||
const LOG_N_COLS: u32 = 8;
|
||||
|
||||
fn bench_blake2s_merkle<B: MerkleOps<Blake2sMerkleHasher>>(c: &mut Criterion, id: &str) {
|
||||
let col: Col<B, BaseField> = (0..1 << LOG_N_ROWS).map(|_| BaseField::zero()).collect();
|
||||
let cols = (0..1 << LOG_N_COLS).map(|_| col.clone()).collect_vec();
|
||||
let col_refs = cols.iter().collect_vec();
|
||||
let mut group = c.benchmark_group("merkle throughput");
|
||||
let n_elements = 1 << (LOG_N_COLS + LOG_N_ROWS);
|
||||
group.throughput(Throughput::Elements(n_elements));
|
||||
group.throughput(Throughput::Bytes(N_BYTES_FELT as u64 * n_elements));
|
||||
group.bench_function(&format!("{id} merkle"), |b| {
|
||||
b.iter_with_large_drop(|| B::commit_on_layer(LOG_N_ROWS, None, &col_refs))
|
||||
});
|
||||
}
|
||||
|
||||
fn blake2s_merkle_benches(c: &mut Criterion) {
|
||||
bench_blake2s_merkle::<SimdBackend>(c, "simd");
|
||||
bench_blake2s_merkle::<CpuBackend>(c, "cpu");
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
name = benches;
|
||||
config = Criterion::default().sample_size(10);
|
||||
targets = blake2s_merkle_benches);
|
||||
criterion_main!(benches);
|
||||
81
Stwo_wrapper/crates/prover/benches/pcs.rs
Normal file
81
Stwo_wrapper/crates/prover/benches/pcs.rs
Normal file
@ -0,0 +1,81 @@
|
||||
use std::iter;
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion};
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use stwo_prover::core::backend::simd::SimdBackend;
|
||||
use stwo_prover::core::backend::{BackendForChannel, CpuBackend};
|
||||
use stwo_prover::core::channel::Blake2sChannel;
|
||||
use stwo_prover::core::fields::m31::BaseField;
|
||||
use stwo_prover::core::pcs::CommitmentTreeProver;
|
||||
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
|
||||
use stwo_prover::core::poly::twiddles::TwiddleTree;
|
||||
use stwo_prover::core::poly::BitReversedOrder;
|
||||
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;
|
||||
|
||||
const LOG_COSET_SIZE: u32 = 20;
|
||||
const LOG_BLOWUP_FACTOR: u32 = 1;
|
||||
const N_POLYS: usize = 16;
|
||||
|
||||
fn benched_fn<B: BackendForChannel<Blake2sMerkleChannel>>(
|
||||
evals: Vec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
|
||||
channel: &mut Blake2sChannel,
|
||||
twiddles: &TwiddleTree<B>,
|
||||
) {
|
||||
let polys = evals
|
||||
.into_iter()
|
||||
.map(|eval| eval.interpolate_with_twiddles(twiddles))
|
||||
.collect();
|
||||
|
||||
CommitmentTreeProver::<B, Blake2sMerkleChannel>::new(
|
||||
polys,
|
||||
LOG_BLOWUP_FACTOR,
|
||||
channel,
|
||||
twiddles,
|
||||
);
|
||||
}
|
||||
|
||||
fn bench_pcs<B: BackendForChannel<Blake2sMerkleChannel>>(c: &mut Criterion, id: &str) {
|
||||
let small_domain = CanonicCoset::new(LOG_COSET_SIZE);
|
||||
let big_domain = CanonicCoset::new(LOG_COSET_SIZE + LOG_BLOWUP_FACTOR);
|
||||
let twiddles = B::precompute_twiddles(big_domain.half_coset());
|
||||
let mut channel = Blake2sChannel::default();
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
|
||||
let evals: Vec<CircleEvaluation<B, BaseField, BitReversedOrder>> = iter::repeat_with(|| {
|
||||
CircleEvaluation::new(
|
||||
small_domain.circle_domain(),
|
||||
(0..1 << LOG_COSET_SIZE).map(|_| rng.gen()).collect(),
|
||||
)
|
||||
})
|
||||
.take(N_POLYS)
|
||||
.collect();
|
||||
|
||||
c.bench_function(
|
||||
&format!("{id} polynomial commitment 2^{LOG_COSET_SIZE}"),
|
||||
|b| {
|
||||
b.iter_batched(
|
||||
|| evals.clone(),
|
||||
|evals| {
|
||||
benched_fn::<B>(
|
||||
black_box(evals),
|
||||
black_box(&mut channel),
|
||||
black_box(&twiddles),
|
||||
)
|
||||
},
|
||||
BatchSize::LargeInput,
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn pcs_benches(c: &mut Criterion) {
|
||||
bench_pcs::<SimdBackend>(c, "simd");
|
||||
bench_pcs::<CpuBackend>(c, "cpu");
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
name = benches;
|
||||
config = Criterion::default().sample_size(10);
|
||||
targets = pcs_benches);
|
||||
criterion_main!(benches);
|
||||
18
Stwo_wrapper/crates/prover/benches/poseidon.rs
Normal file
18
Stwo_wrapper/crates/prover/benches/poseidon.rs
Normal file
@ -0,0 +1,18 @@
|
||||
use criterion::{criterion_group, criterion_main, Criterion, Throughput};
|
||||
use stwo_prover::core::pcs::PcsConfig;
|
||||
use stwo_prover::examples::poseidon::prove_poseidon;
|
||||
|
||||
pub fn simd_poseidon(c: &mut Criterion) {
|
||||
const LOG_N_INSTANCES: u32 = 18;
|
||||
let mut group = c.benchmark_group("poseidon2");
|
||||
group.throughput(Throughput::Elements(1u64 << LOG_N_INSTANCES));
|
||||
group.bench_function(format!("poseidon2 2^{} instances", LOG_N_INSTANCES), |b| {
|
||||
b.iter(|| prove_poseidon(LOG_N_INSTANCES, PcsConfig::default()));
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
name = bit_rev;
|
||||
config = Criterion::default().sample_size(10);
|
||||
targets = simd_poseidon);
|
||||
criterion_main!(bit_rev);
|
||||
19
Stwo_wrapper/crates/prover/benches/prefix_sum.rs
Normal file
19
Stwo_wrapper/crates/prover/benches/prefix_sum.rs
Normal file
@ -0,0 +1,19 @@
|
||||
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
|
||||
use stwo_prover::core::backend::simd::column::BaseColumn;
|
||||
use stwo_prover::core::backend::simd::prefix_sum::inclusive_prefix_sum;
|
||||
use stwo_prover::core::fields::m31::BaseField;
|
||||
|
||||
pub fn simd_prefix_sum_bench(c: &mut Criterion) {
|
||||
const LOG_SIZE: u32 = 24;
|
||||
let evals: BaseColumn = (0..1 << LOG_SIZE).map(BaseField::from).collect();
|
||||
c.bench_function(&format!("simd prefix_sum 2^{LOG_SIZE}"), |b| {
|
||||
b.iter_batched(
|
||||
|| evals.clone(),
|
||||
inclusive_prefix_sum,
|
||||
BatchSize::LargeInput,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(benches, simd_prefix_sum_bench);
|
||||
criterion_main!(benches);
|
||||
55
Stwo_wrapper/crates/prover/benches/quotients.rs
Normal file
55
Stwo_wrapper/crates/prover/benches/quotients.rs
Normal file
@ -0,0 +1,55 @@
|
||||
#![feature(iter_array_chunks)]
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
use itertools::Itertools;
|
||||
use stwo_prover::core::backend::cpu::CpuBackend;
|
||||
use stwo_prover::core::backend::simd::SimdBackend;
|
||||
use stwo_prover::core::circle::SECURE_FIELD_CIRCLE_GEN;
|
||||
use stwo_prover::core::fields::m31::BaseField;
|
||||
use stwo_prover::core::fields::qm31::SecureField;
|
||||
use stwo_prover::core::pcs::quotients::{ColumnSampleBatch, QuotientOps};
|
||||
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
|
||||
use stwo_prover::core::poly::BitReversedOrder;
|
||||
|
||||
// TODO(andrew): Consider removing const generics and making all sizes the same.
|
||||
fn bench_quotients<B: QuotientOps, const LOG_N_ROWS: u32, const LOG_N_COLS: u32>(
|
||||
c: &mut Criterion,
|
||||
id: &str,
|
||||
) {
|
||||
let domain = CanonicCoset::new(LOG_N_ROWS).circle_domain();
|
||||
let values = (0..domain.size()).map(BaseField::from).collect();
|
||||
let col = CircleEvaluation::<B, BaseField, BitReversedOrder>::new(domain, values);
|
||||
let cols = (0..1 << LOG_N_COLS).map(|_| col.clone()).collect_vec();
|
||||
let col_refs = cols.iter().collect_vec();
|
||||
let random_coeff = SecureField::from_u32_unchecked(0, 1, 2, 3);
|
||||
let a = SecureField::from_u32_unchecked(5, 6, 7, 8);
|
||||
let samples = vec![ColumnSampleBatch {
|
||||
point: SECURE_FIELD_CIRCLE_GEN,
|
||||
columns_and_values: (0..1 << LOG_N_COLS).map(|i| (i, a)).collect(),
|
||||
}];
|
||||
c.bench_function(
|
||||
&format!("{id} quotients 2^{LOG_N_COLS} x 2^{LOG_N_ROWS}"),
|
||||
|b| {
|
||||
b.iter_with_large_drop(|| {
|
||||
B::accumulate_quotients(
|
||||
black_box(domain),
|
||||
black_box(&col_refs),
|
||||
black_box(random_coeff),
|
||||
black_box(&samples),
|
||||
1,
|
||||
)
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn quotients_benches(c: &mut Criterion) {
|
||||
bench_quotients::<SimdBackend, 20, 8>(c, "simd");
|
||||
bench_quotients::<CpuBackend, 16, 8>(c, "cpu");
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
name = benches;
|
||||
config = Criterion::default().sample_size(10);
|
||||
targets = quotients_benches);
|
||||
criterion_main!(benches);
|
||||
348
Stwo_wrapper/crates/prover/proof.json
Normal file
348
Stwo_wrapper/crates/prover/proof.json
Normal file
@ -0,0 +1,348 @@
|
||||
{
|
||||
"commitments" :
|
||||
["34328580272026076035687604093297365442785733592720865218001799813393342152908",
|
||||
"38388381845372648579572899115609862601821983406101214230086519922780265042634"],
|
||||
|
||||
"sampled_values_0" :
|
||||
[["1","0","0","0"],
|
||||
["2129160320","1109509513","787887008","1676461964"],
|
||||
["262908602","915488457","1893945291","1774327476"],
|
||||
["894719153","1570509766","1424186619","204092576"],
|
||||
["397490811","836398274","1615765624","2013800563"],
|
||||
["1022303904","276983775","1064742229","165204856"],
|
||||
["1200363525","170838026","524999776","156116441"],
|
||||
["850733526","448725560","1521962209","1318190714"],
|
||||
["1187866075","1705588092","924088348","490002418"],
|
||||
["2033565088","996780784","1820235518","2048788344"],
|
||||
["2061590372","1150986157","711772586","1511398564"],
|
||||
["1066623954","530384603","1890251380","1699008129"],
|
||||
["734047580","1685768538","505142109","787113212"],
|
||||
["2030904700","99932423","695391286","1736941035"],
|
||||
["1580330105","932031717","1705998668","146411959"],
|
||||
["1585732224","1556242253","941668238","1998570239"],
|
||||
["199481433","2123320403","1257464748","1663811899"],
|
||||
["2139019524","1547107722","728449250","1941851166"],
|
||||
["752079023","268472135","1465850435","16510773"],
|
||||
["1279312817","63252415","442230579","1560954631"],
|
||||
["1074859131","137997593","2118329011","652535723"],
|
||||
["297567647","1483381078","1941495981","599737348"],
|
||||
["1735543786","1420676479","1354982762","1114211268"],
|
||||
["1691705401","1143446295","1748115479","1666756627"],
|
||||
["955696743","2077778309","736065989","1319443838"],
|
||||
["1076874307","1001483910","1702287354","819727011"],
|
||||
["1134989244","1823710400","2067694105","1098263343"],
|
||||
["1793642608","961404475","1279773056","1815400043"],
|
||||
["739677274","1827877577","838562378","171296720"],
|
||||
["2036367121","1901888610","289723252","2014426907"],
|
||||
["330020507","436937516","2113056521","1828501207"],
|
||||
["1359068814","583899921","734628376","1223217137"],
|
||||
["1319501520","1242972089","1202216521","1285024997"],
|
||||
["681182370","1569622309","1574376904","1563950435"],
|
||||
["1204519566","483612224","1677731115","1667757584"],
|
||||
["330284364","917877098","57538161","179869993"],
|
||||
["2056561198","119768893","740294154","1454562198"],
|
||||
["79009084","545196641","13388962","1973400144"],
|
||||
["885977898","1973300145","37115619","957100699"],
|
||||
["1937449867","1777683674","1983002799","757662558"],
|
||||
["344927561","357845689","26887161","664585634"],
|
||||
["1462268220","615463524","209500386","44308852"],
|
||||
["570984705","2022111132","1404632615","2119081660"],
|
||||
["13183327","1584451280","1216116653","316345540"],
|
||||
["1497965915","705236857","1892466476","1068567492"],
|
||||
["1758694676","1408790161","1140545981","315723937"],
|
||||
["645308461","1125824784","1786470558","1240927727"],
|
||||
["1213464061","470930291","1718629724","1149088875"],
|
||||
["214577693","1578610321","2133720991","226291629"],
|
||||
["1357706729","2097875841","1767996253","1478111500"],
|
||||
["1154658683","752162439","2018723944","163997560"],
|
||||
["1051993583","703716977","379706674","487262860"],
|
||||
["1017692573","2060296775","2001023083","1064213951"],
|
||||
["1042587725","1701108370","204550428","904590130"],
|
||||
["1115340870","743420370","1927225111","1276396551"],
|
||||
["493638626","1874789377","47342513","209203758"],
|
||||
["1558586505","83459476","247638703","1975504267"],
|
||||
["2097068784","954319448","367516919","1545761518"],
|
||||
["1655645294","352838520","1307263981","1110198118"],
|
||||
["1169856046","1925368371","1362317240","1926032147"],
|
||||
["1940113709","885624001","1395047654","80053995"],
|
||||
["1778932990","25092730","201117282","1724571908"],
|
||||
["2096327738","233411984","1247443120","713989449"],
|
||||
["808532602","136577890","1015579288","38900716"],
|
||||
["1182257782","1186245376","1451332036","2080170103"],
|
||||
["1662610758","1505542080","1038243031","1889715771"],
|
||||
["440146119","942837214","1440484295","1593949278"],
|
||||
["46258268","1884246120","164930024","2050584510"],
|
||||
["1198954868","1079638495","1424072583","1028611344"],
|
||||
["2112984649","1531382496","1873151714","1818301795"],
|
||||
["1554382282","253920307","1641628530","1378998084"],
|
||||
["857898234","686236793","2091871553","184978860"],
|
||||
["2049153599","6111471","1579475775","32492894"],
|
||||
["1371356596","679072793","1547377985","354305233"],
|
||||
["1799882226","1201472049","1592617716","125534957"],
|
||||
["1277144880","253726080","1800145982","1125162267"],
|
||||
["1577717920","440984421","1377891036","846453148"],
|
||||
["1952731919","1710992214","673668053","1871913638"],
|
||||
["1559011028","2060945859","719954448","1356468891"],
|
||||
["1961642242","1693473944","1300152522","412222111"],
|
||||
["861208187","1242659514","977183954","38730935"],
|
||||
["1016984917","1368361439","2106430139","1225979890"],
|
||||
["1427754325","1482206106","1465316380","1096279813"],
|
||||
["566051043","2025874544","234976335","1482256978"],
|
||||
["1750543495","1494541462","374330732","411642241"],
|
||||
["230654343","55625728","136463431","1099606808"],
|
||||
["1172218793","1260458608","1314942990","75527287"],
|
||||
["1824515276","916178746","1300275105","370626746"],
|
||||
["915931367","987018043","56193044","617907884"],
|
||||
["1934695822","1112844637","609268252","1972086910"],
|
||||
["619631651","152029630","1979976905","292597437"],
|
||||
["62258350","1890115432","1373605674","1505619938"],
|
||||
["1770422019","1398189304","1773172351","1576001433"],
|
||||
["650940868","1756047014","1764798953","1146887875"],
|
||||
["1746945043","528205234","778346028","1797468521"],
|
||||
["760802416","1479409742","1556974632","1307498378"],
|
||||
["102511022","1787975482","968854748","1010240763"],
|
||||
["330722054","2046294448","14132125","1822414050"],
|
||||
["943548871","1770900623","1861740461","1290634078"],
|
||||
["1402661415","1361511065","1784889120","837615360"]],
|
||||
|
||||
"sampled_values_1" :
|
||||
[["712066144","1576368753","626134398","426337436"],
|
||||
["160634493","1096735733","992622982","964509862"],
|
||||
["208900621","1128739590","1423579079","1688318061"],
|
||||
["1029182234","1152361165","571476481","1593867154"]],
|
||||
|
||||
"decommitment_0" :
|
||||
["24311567319749512546399129581715033328970605051392227451685196018312506896509",
|
||||
"8450134967305372517473027560161707471995673370792264153422077885080332622841",
|
||||
"6431507699794114682519586182713221908058047520896405293833270087934517909753",
|
||||
"303109001984349840640377328716025252051982378448629744935456455431709129012",
|
||||
"47167328465744900593371601186109726758160197572292632388959155138584359581158",
|
||||
"50584492046778480438774038937088410409133167768957478525289857065775850658491",
|
||||
"30584798499699103841624545814425958941934653399588880797257122471101102880636",
|
||||
"33441256878213890325682161124370878299436204406591246133637659120215439522803",
|
||||
"3288124068330032280185519028600654292250668929588668389702892483946668251740",
|
||||
"29852774919556057664485671676242264613416836486089146650713214180894511265116",
|
||||
"12482060975231949385592255321766253365687502822944549845564491620341379321204",
|
||||
"46285234163162336949700608657672147469543559995399282843606812790099228411758",
|
||||
"20807128972645591294726020136444795908525656782422245307591812614900798799914"],
|
||||
|
||||
"decommitment_1" :
|
||||
["18063303111481257844109225560025890393366258018933166919604543575686388632162",
|
||||
"1676364734386980395984608216327451243278421019544108756198322792517099196249",
|
||||
"5278661052518480850653886996628582549184134231869598116690316714367933376948",
|
||||
"21983822689977371558234298346357617674436224016274009820764238516520240403273",
|
||||
"605332543427376153930374757063581881998320956602375739165671986207155079359",
|
||||
"33771702906041565783498389127165108212044382608172583325407071671862086994048",
|
||||
"6930451780154275491146135719028766497977496109537963233244808739657647563071",
|
||||
"4564117668410212714684125903928600765456322272915099403587425756488534507713",
|
||||
"337808767671877648828499299861821973796749820854854708379479049898835100991",
|
||||
"2725840354457305623692571800192492803162041315546256970381708693201407812833",
|
||||
"34716495111790106826563330917176360656701717867702196654990744866499300990003",
|
||||
"49719870445464463616785535809529171382800153139923763422202182182379572350737",
|
||||
"48664746464641275030915461677298150155193593333108431337941029583245720868695",
|
||||
"32557886668297237033601675259512842580727821006475208499640136322794706303894",
|
||||
"15014835402414421167586357788116276188694467622586221351644991310645286648480",
|
||||
"12973705814659120511327850727547995427054971555827754777395732787633567627149",
|
||||
"15379912850866472398958956306527914195058439699787840041152620034933267404138",
|
||||
"19859070819439084101412868355121176941090844577507824922960697697668791429525",
|
||||
"38273559034692361632775489704953448699371080776239846995670381153186834620044"],
|
||||
|
||||
"queried_values_0" :
|
||||
[["1","1","1","1","1","1"],
|
||||
["730457281","730490049","28918683","28885915","1656126010","1656093242"],
|
||||
["855614122","1238037465","1836504291","355791428","757095818","467806903"],
|
||||
["674179888","1530445315","1720543014","76190330","1475912409","1017215862"],
|
||||
["1142290008","1148671853","1619097781","938511401","904357795","257652679"],
|
||||
["1679234056","1355264641","2139729457","574756654","604307234","1146556949"],
|
||||
["1000500309","2008905806","1442759180","598876729","1786070690","1072293976"],
|
||||
["1119085545","2133345582","135683580","216214405","1049766224","943727969"],
|
||||
["206423262","2047139937","305085364","1422472664","1826554088","1032095092"],
|
||||
["501238882","1656305868","724710382","1949772461","1426787917","585368894"],
|
||||
["1005468045","1775577441","1042182076","415631363","1067013227","1635705270"],
|
||||
["1776076392","216798814","1525036520","1160666510","1212132211","1915058776"],
|
||||
["859923105","1633989410","182110635","2060185314","1084464822","1129902257"],
|
||||
["489437802","313401022","271315129","357612175","2050381179","647577687"],
|
||||
["1495302158","2052264981","1498165299","1164417520","1050104037","450244199"],
|
||||
["1084986392","398966983","808449145","1733554138","2068501028","659474347"],
|
||||
["399458768","1789245133","1698759035","188433436","1794535430","364419824"],
|
||||
["2013965647","722839714","928854328","124488895","1378959529","952886009"],
|
||||
["1334765706","193402268","471076108","640800921","1998121783","961582406"],
|
||||
["1067762968","381831281","560459357","1025929344","181659877","1922040224"],
|
||||
["1993303462","467991218","849673597","744722836","239634354","329631295"],
|
||||
["785794488","1649178388","672964420","1281255462","900602801","271501809"],
|
||||
["857859728","1325395820","985014020","1094321795","259553347","774587048"],
|
||||
["1214640090","1588569866","871717820","1131833706","1625896842","1635087550"],
|
||||
["796549205","931495223","2018253108","1395065060","158209751","1160478135"],
|
||||
["883143962","729115354","190207821","839273168","1668931939","2074584689"],
|
||||
["1490296658","1846956206","1610364850","56422972","160482417","681872093"],
|
||||
["1270585092","1910190167","464113273","613529242","1027101122","1014185686"],
|
||||
["1456043179","1999662961","193940913","678382864","39040067","1236859818"],
|
||||
["1626243617","901735777","1703169024","911300891","1640727682","1121874896"],
|
||||
["492192896","15672698","319327174","1727120334","1965889437","114404366"],
|
||||
["407079019","949462637","255390508","1753162095","501134776","1457122467"],
|
||||
["1478573872","1439193434","1053200675","1001140887","1553935777","1253681552"],
|
||||
["183135520","946237525","1802924023","1831496784","1893117930","1830486286"],
|
||||
["234902670","1169030504","196055115","1323151968","855748623","1328842866"],
|
||||
["1150999776","1338824346","2072101698","774206263","1967350016","1808817867"],
|
||||
["924341552","1430286424","511268814","825025920","1061850574","1954646566"],
|
||||
["302634890","314434153","1692670768","1822915313","1244352075","1953834230"],
|
||||
["1576167467","687837005","2116136752","144109400","1590157548","1634932462"],
|
||||
["396756275","1272134898","1207308240","818219166","1314182589","109494000"],
|
||||
["846425160","897737569","757312164","826009489","1019831588","1977463051"],
|
||||
["2065801114","1918982367","1548689186","2082631803","298112070","383438809"],
|
||||
["1034102289","461735180","2115581275","1343026598","1229979058","1021418523"],
|
||||
["1784173874","166635387","547550115","1094693960","573193735","451367040"],
|
||||
["119818313","659105018","1741377697","8940733","911200334","511474518"],
|
||||
["1511949880","1315119529","1267019200","2134944693","878254810","375758264"],
|
||||
["1203050254","156394547","1348568635","412863443","1068659960","1407913814"],
|
||||
["361779719","130417374","89109096","117994876","1151322919","863143484"],
|
||||
["1007476533","989566160","138644964","1672742874","540141118","1296408100"],
|
||||
["1824241144","2051199719","1863718547","2109877864","36689613","1055926854"],
|
||||
["791693003","1433717239","991140958","1565955371","1839976870","1163947838"],
|
||||
["1267320759","2102593211","1831360854","1691591439","1672201908","61327345"],
|
||||
["301343164","277158258","627925439","577508975","1896464649","907629062"],
|
||||
["1964268932","929590164","1529686876","68630644","1663063136","254082844"],
|
||||
["693529348","1815295486","1660565870","1226857377","156312343","1500907098"],
|
||||
["1723158753","252348225","253985470","52424437","1605949937","576572581"],
|
||||
["1781792048","1497492716","1951824572","1156925855","863650708","156447987"],
|
||||
["876432605","458503399","283092867","247883110","1227074181","966219235"],
|
||||
["1581118191","66527915","1577039825","1227961402","1738412997","1862462297"],
|
||||
["679458448","338624032","34185999","253532412","65409631","563033132"],
|
||||
["1011967612","1898273226","2124401156","105282260","1188226330","123913515"],
|
||||
["1432513005","1083162825","1299704150","1184276814","339749370","1064298821"],
|
||||
["145070927","1250457746","1977306722","1035124433","322154361","1782232869"],
|
||||
["1934256728","1269667423","248999825","267009036","132507662","474021315"],
|
||||
["1694871453","535187899","981724703","1180312550","74370795","277702656"],
|
||||
["1850841912","1037634121","1967377497","1755127193","1566449422","1939039785"],
|
||||
["1315699598","476111459","2058733537","332263289","1592057567","874912616"],
|
||||
["774536537","170060616","2086574090","47894465","778021586","2115296942"],
|
||||
["322468558","24934377","637275739","1596346002","1896623296","1814433409"],
|
||||
["766517428","1263076038","358941187","2070217919","2108397185","1587546402"],
|
||||
["345404490","1065320570","1231275245","1037359122","1286389839","2070140848"],
|
||||
["746521574","835067673","311114030","1586400488","1406022058","1284151326"],
|
||||
["1857315969","431410759","825259098","1717904860","503708539","2097758215"],
|
||||
["1879479734","1863555039","2108235515","1833922769","1562707156","49484002"],
|
||||
["1366768987","1050390036","1491845132","666041968","74368055","1254335623"],
|
||||
["188857287","1161878039","1771805176","1457666227","1157868840","486461459"],
|
||||
["261764705","1577846886","1332322961","10423372","640027252","1086814656"],
|
||||
["111907709","542625019","2021749229","2013008690","523703611","1328833940"],
|
||||
["1270684472","1675989474","394214608","538100201","1984625073","1560563159"],
|
||||
["666555709","852557426","651115051","1878827907","953346499","619017191"],
|
||||
["747972907","149382079","1393306586","1394823957","960994901","536632180"],
|
||||
["80535893","229602380","1817483938","1455260088","484432","1869486290"],
|
||||
["443556931","253108261","1609174393","1245931188","752691602","1668543792"],
|
||||
["745497042","854686466","1834097777","642389535","1284043061","896553209"],
|
||||
["532777064","1491985134","200157005","1378855967","1159213374","1797221037"],
|
||||
["933463176","813761538","1124049829","1988347055","2115297439","1836576920"],
|
||||
["194436043","1437728625","43998833","786326005","1130428925","1424571033"],
|
||||
["2108272636","1410841489","753065553","2020187193","1644376367","670324352"],
|
||||
["1362448669","752702510","1740531646","47989265","588634","1940480814"],
|
||||
["439960422","528245604","179496898","235775013","59000527","1903150726"],
|
||||
["103605138","249162711","1971219628","1958189530","423278905","1318354885"],
|
||||
["321504059","1595801356","596911575","1361967073","459661104","599048233"],
|
||||
["1610552125","73166668","444776743","1820306524","1180674369","1570356756"],
|
||||
["1283703846","1024562975","958477092","1329464736","1758672211","899108631"],
|
||||
["713626137","904634570","1902566483","1938333063","447549083","703262660"],
|
||||
["1417291696","1717451368","354584524","832751684","930128006","1037860604"],
|
||||
["1618745108","79533863","301008038","2091942909","1221962725","1524945081"],
|
||||
["1490224031","349040760","393684137","484089443","1912848485","1790207999"],
|
||||
["1930411520","642009628","1138820074","31855314","1177766391","1913457637"],
|
||||
["618628623","139430131","904498895","925273128","2111653256","1012250155"]],
|
||||
|
||||
"queried_values_1" : [["316772341","1526280133","663010112","224983897","510598760","1109503351"],
|
||||
["754832207","435790299","883623752","553207508","154784232","199176676"],
|
||||
["689603315","1763523007","1720552945","1983603154","367841669","319325418"],
|
||||
["1290247052","1120744584","193500372","294491115","951360807","891034447"]],
|
||||
|
||||
"proof of work" : "43",
|
||||
|
||||
"coeffs" : ["329725079","667313404","2083859876","1645693780"],
|
||||
|
||||
"inner_commitment_0" : "45555755014923146766476222823122654194153582923386372098787021809602091298670",
|
||||
|
||||
"inner_decommitment_0" :
|
||||
["5462985033728555575703006689913665598917262836577853690370070265073174719979",
|
||||
"1439463537898028163672031322449473184361454650351831891678420729829634422052",
|
||||
"3457992889375232257419443221173459860069710025424105705342440558952480618733",
|
||||
"1560236756242436900530359200127475252176167690738804367783218449678388748008",
|
||||
"44704035195642947074853484428073358106137689624692539529545008114744235429869",
|
||||
"1823722396825684657353300460130959150013626547749061669644787096209462088081",
|
||||
"36173961213090661587166983586606381085726394354099487131775975739295244107044",
|
||||
"44982414287882497869227873977178787463294164719732705041425254748922385395330",
|
||||
"45309152665928987599895709328729600158686240484613667382186228529382344291065",
|
||||
"6727671460449390719545303211873485738734289534911560838410979450236173250575",
|
||||
"20261836592905556803606993653450529841562498576647326913889589677099044443740",
|
||||
"5092741370842944151567575040255811031129982749889308756570945791769905531735",
|
||||
"25669118913473215056777887930215946759071827607816131638248897691395809170466"],
|
||||
|
||||
"inner_evals_subset_0" :
|
||||
[["1779738283","1440487135","622229563","2027928845"],
|
||||
["3310770","2003547458","1663490902","2105455978"],
|
||||
["824310523","1757518542","231582441","427507918"]],
|
||||
|
||||
"inner_commitment_1" : "36809196688918736785151875655523363274066842107451407514854169291514772437712",
|
||||
|
||||
"inner_decommitment_1" :
|
||||
["45836684762941279488847688946861562185184567552524644394596887832754938056979",
|
||||
"19073399120361742411025793373596546929954434607182120440274945820526786807031",
|
||||
"10149408627738268983458395386544516732232569888429637086606165543488301475287",
|
||||
"1134561121197915913438467336074096605141057500441242329860673894339623319838",
|
||||
"4117331370612733116088255919125911273783755080775248279590234839282346655209",
|
||||
"10513576479117646185507147947418059201301955276798193127377481002238952180861",
|
||||
"23342730846670478933566004300392843380760787769231395287565422468403932432226",
|
||||
"18740447275326464369699774196183138103532792867824812360537906002797944316369",
|
||||
"14414140432147963417180883803716587948502362440108737731008978168633980398219",
|
||||
"12250577863311462987061070846596950168057013378171600204130942328863618210853"],
|
||||
|
||||
"inner_evals_subset_1" :
|
||||
[["1993244979","993480712","1202910330","51538179"],
|
||||
["1550507975","1444313548","117070947","1740854590"],
|
||||
["342709844","601149328","1436490544","1384381104"]],
|
||||
|
||||
"inner_commitment_2" : "881819876414071785116043072319901261019430342891967010427739931379717752179",
|
||||
|
||||
"inner_decommitment_2" :
|
||||
["3578061893060038632121836895066391994380789049478144565795630968916166958170",
|
||||
"37159117331259094338569510749131708143302385613991325356583099983565421599794",
|
||||
"37757292341028790385725896863222501721662517262004109998718347744172059822092",
|
||||
"32039058093629437909283983486827138804118789892152561698628703527746673098335",
|
||||
"11826129489683399268430621966121477905203281789801973078750629518404118320344",
|
||||
"12209931412475259378582945309327673629248630080862144675946167350377528046777",
|
||||
"23652788355891929123107872712785512551735081591713325931570432712643436668697"],
|
||||
|
||||
"inner_evals_subset_2" :
|
||||
[["1184374976","203117688","803515935","1781630737"],
|
||||
["716686867","138132852","2024080584","392488646"],
|
||||
["606958913","308986056","258114411","2075401741"]],
|
||||
|
||||
"inner_commitment_3" : "12027868599227153144742193285247060272784688895537159385948117930165143367635",
|
||||
|
||||
"inner_decommitment_3" :
|
||||
["46510393127320994984678433868779353751380750916123221099220042551249644407575",
|
||||
"39205406309151711272632862117612882165913578213907083849769292791373274291092",
|
||||
"1794620160371318211300608665903463515892826349097148444913826356018179256828",
|
||||
"30764929521910551485060681298929856016753563990361023701791813032517030674324"],
|
||||
|
||||
"inner_evals_subset_3" :
|
||||
[["1682558787","129089003","784689440","491206249"],
|
||||
["2038210709","1600238918","655676259","1542271403"],
|
||||
["1014579052","1384080403","862591487","1941843578"]],
|
||||
|
||||
"inner_commitment_4" : "14412699124489796400221638504502701429059474709940645969751865021865037702257",
|
||||
|
||||
"inner_decommitment_4" :
|
||||
["19324767949149751902195880760061491860991545124249052461044586503970003688610"],
|
||||
|
||||
"inner_evals_subset_4" :
|
||||
[["683258805","1002722262","1583421272","1748673499"],
|
||||
["2101847208","689925082","1602280602","1942656531"],
|
||||
["1952987285","1995490213","2082219584","1620868519"]],
|
||||
|
||||
"inner_commitment_5" : "7898658461322497542615494384418597990320440781916208640366283709415937814000",
|
||||
|
||||
"inner_decommitment_5" :
|
||||
[],
|
||||
|
||||
"inner_evals_subset_5" :
|
||||
[["1862940297","2014478284","1043383827","560191545"]]
|
||||
}
|
||||
@ -0,0 +1,84 @@
|
||||
use num_traits::{One, Zero};
|
||||
|
||||
use super::EvalAtRow;
|
||||
use crate::core::backend::{Backend, Column};
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
|
||||
use crate::core::pcs::TreeVec;
|
||||
use crate::core::poly::circle::{CanonicCoset, CirclePoly};
|
||||
use crate::core::utils::circle_domain_order_to_coset_order;
|
||||
|
||||
/// Evaluates expressions at a trace domain row, and asserts constraints. Mainly used for testing.
|
||||
pub struct AssertEvaluator<'a> {
|
||||
pub trace: &'a TreeVec<Vec<Vec<BaseField>>>,
|
||||
pub col_index: TreeVec<usize>,
|
||||
pub row: usize,
|
||||
}
|
||||
impl<'a> AssertEvaluator<'a> {
|
||||
pub fn new(trace: &'a TreeVec<Vec<Vec<BaseField>>>, row: usize) -> Self {
|
||||
Self {
|
||||
trace,
|
||||
col_index: TreeVec::new(vec![0; trace.len()]),
|
||||
row,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<'a> EvalAtRow for AssertEvaluator<'a> {
|
||||
type F = BaseField;
|
||||
type EF = SecureField;
|
||||
|
||||
fn next_interaction_mask<const N: usize>(
|
||||
&mut self,
|
||||
interaction: usize,
|
||||
offsets: [isize; N],
|
||||
) -> [Self::F; N] {
|
||||
let col_index = self.col_index[interaction];
|
||||
self.col_index[interaction] += 1;
|
||||
offsets.map(|off| {
|
||||
// The mask row might wrap around the column size.
|
||||
let col_size = self.trace[interaction][col_index].len() as isize;
|
||||
self.trace[interaction][col_index]
|
||||
[(self.row as isize + off).rem_euclid(col_size) as usize]
|
||||
})
|
||||
}
|
||||
|
||||
fn add_constraint<G>(&mut self, constraint: G)
|
||||
where
|
||||
Self::EF: std::ops::Mul<G, Output = Self::EF>,
|
||||
{
|
||||
// Cast to SecureField.
|
||||
let res = SecureField::one() * constraint;
|
||||
// The constraint should be zero at the given row, since we are evaluating on the trace
|
||||
// domain.
|
||||
assert_eq!(res, SecureField::zero(), "row: {}", self.row);
|
||||
}
|
||||
|
||||
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
|
||||
SecureField::from_m31_array(values)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assert_constraints<B: Backend>(
|
||||
trace_polys: &TreeVec<Vec<CirclePoly<B>>>,
|
||||
trace_domain: CanonicCoset,
|
||||
assert_func: impl Fn(AssertEvaluator<'_>),
|
||||
) {
|
||||
let traces = trace_polys.as_ref().map(|tree| {
|
||||
tree.iter()
|
||||
.map(|poly| {
|
||||
circle_domain_order_to_coset_order(
|
||||
&poly
|
||||
.evaluate(trace_domain.circle_domain())
|
||||
.bit_reverse()
|
||||
.values
|
||||
.to_cpu(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
for row in 0..trace_domain.size() {
|
||||
let eval = AssertEvaluator::new(&traces, row);
|
||||
assert_func(eval);
|
||||
}
|
||||
}
|
||||
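A minimal sketch of how `assert_constraints` might be driven from a test, assuming a `FrameworkEval`-style component `my_eval` and its interpolated trace polynomials `trace_polys` (both hypothetical names):

// Sketch only: `my_eval` and `trace_polys` are assumed to exist in the test.
let trace_domain = CanonicCoset::new(my_eval.log_size());
assert_constraints(&trace_polys, trace_domain, |eval| {
    // Re-run the component's constraint expressions on every trace row;
    // AssertEvaluator panics if any constraint evaluates to a nonzero value.
    my_eval.evaluate(eval);
});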
210
Stwo_wrapper/crates/prover/src/constraint_framework/component.rs
Normal file
@ -0,0 +1,210 @@
|
||||
use std::borrow::Cow;
|
||||
use std::iter::zip;
|
||||
use std::ops::Deref;
|
||||
|
||||
use itertools::Itertools;
|
||||
use tracing::{span, Level};
|
||||
|
||||
use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
|
||||
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
|
||||
use crate::core::air::{Component, ComponentProver, Trace};
|
||||
use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords;
|
||||
use crate::core::backend::simd::m31::LOG_N_LANES;
|
||||
use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS};
|
||||
use crate::core::backend::simd::SimdBackend;
|
||||
use crate::core::circle::CirclePoint;
|
||||
use crate::core::constraints::coset_vanishing;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::FieldExpOps;
|
||||
use crate::core::pcs::{TreeSubspan, TreeVec};
|
||||
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
|
||||
use crate::core::poly::BitReversedOrder;
|
||||
use crate::core::{utils, ColumnVec};
|
||||
|
||||
// TODO(andrew): Docs.
|
||||
// TODO(andrew): Consider better location for this.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct TraceLocationAllocator {
|
||||
/// Mapping of tree index to next available column offset.
|
||||
next_tree_offsets: TreeVec<usize>,
|
||||
}
|
||||
|
||||
impl TraceLocationAllocator {
|
||||
fn next_for_structure<T>(&mut self, structure: &TreeVec<ColumnVec<T>>) -> TreeVec<TreeSubspan> {
|
||||
if structure.len() > self.next_tree_offsets.len() {
|
||||
self.next_tree_offsets.resize(structure.len(), 0);
|
||||
}
|
||||
|
||||
TreeVec::new(
|
||||
zip(&mut *self.next_tree_offsets, &**structure)
|
||||
.enumerate()
|
||||
.map(|(tree_index, (offset, cols))| {
|
||||
let col_start = *offset;
|
||||
let col_end = col_start + cols.len();
|
||||
*offset = col_end;
|
||||
TreeSubspan {
|
||||
tree_index,
|
||||
col_start,
|
||||
col_end,
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// A component defined solely in terms of the constraint framework.
|
||||
/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for
|
||||
/// the SIMD backend.
|
||||
/// Note that the constraint framework only supports components with columns of the same size.
|
||||
pub trait FrameworkEval {
|
||||
fn log_size(&self) -> u32;
|
||||
|
||||
fn max_constraint_log_degree_bound(&self) -> u32;
|
||||
|
||||
fn evaluate<E: EvalAtRow>(&self, eval: E) -> E;
|
||||
}
|
||||
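As an illustration of the trait (a sketch, not part of this commit), a hypothetical component whose single column must be boolean could implement `FrameworkEval` as follows:

// Hypothetical example: one base column x with the constraint x * x - x = 0.
struct IsBoolEval {
    log_size: u32,
}

impl FrameworkEval for IsBoolEval {
    fn log_size(&self) -> u32 {
        self.log_size
    }
    fn max_constraint_log_degree_bound(&self) -> u32 {
        // The constraint has degree 2, so one extra bit of blowup suffices.
        self.log_size + 1
    }
    fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
        let x = eval.next_trace_mask();
        eval.add_constraint(x * x - x);
        eval
    }
}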
|
||||
pub struct FrameworkComponent<C: FrameworkEval> {
|
||||
eval: C,
|
||||
trace_locations: TreeVec<TreeSubspan>,
|
||||
}
|
||||
|
||||
impl<E: FrameworkEval> FrameworkComponent<E> {
|
||||
pub fn new(provider: &mut TraceLocationAllocator, eval: E) -> Self {
|
||||
let eval_tree_structure = eval.evaluate(InfoEvaluator::default()).mask_offsets;
|
||||
let trace_locations = provider.next_for_structure(&eval_tree_structure);
|
||||
Self {
|
||||
eval,
|
||||
trace_locations,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: FrameworkEval> Component for FrameworkComponent<E> {
|
||||
fn n_constraints(&self) -> usize {
|
||||
self.eval.evaluate(InfoEvaluator::default()).n_constraints
|
||||
}
|
||||
|
||||
fn max_constraint_log_degree_bound(&self) -> u32 {
|
||||
self.eval.max_constraint_log_degree_bound()
|
||||
}
|
||||
|
||||
fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
|
||||
TreeVec::new(
|
||||
self.eval
|
||||
.evaluate(InfoEvaluator::default())
|
||||
.mask_offsets
|
||||
.iter()
|
||||
.map(|tree_masks| vec![self.eval.log_size(); tree_masks.len()])
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn mask_points(
|
||||
&self,
|
||||
point: CirclePoint<SecureField>,
|
||||
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
|
||||
let info = self.eval.evaluate(InfoEvaluator::default());
|
||||
let trace_step = CanonicCoset::new(self.eval.log_size()).step();
|
||||
info.mask_offsets.map_cols(|col_mask| {
|
||||
col_mask
|
||||
.iter()
|
||||
.map(|off| point + trace_step.mul_signed(*off).into_ef())
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
fn evaluate_constraint_quotients_at_point(
|
||||
&self,
|
||||
point: CirclePoint<SecureField>,
|
||||
mask: &TreeVec<ColumnVec<Vec<SecureField>>>,
|
||||
evaluation_accumulator: &mut PointEvaluationAccumulator,
|
||||
) {
|
||||
self.eval.evaluate(PointEvaluator::new(
|
||||
mask.sub_tree(&self.trace_locations),
|
||||
evaluation_accumulator,
|
||||
coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: FrameworkEval> ComponentProver<SimdBackend> for FrameworkComponent<E> {
|
||||
fn evaluate_constraint_quotients_on_domain(
|
||||
&self,
|
||||
trace: &Trace<'_, SimdBackend>,
|
||||
evaluation_accumulator: &mut DomainEvaluationAccumulator<SimdBackend>,
|
||||
) {
|
||||
let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain();
|
||||
let trace_domain = CanonicCoset::new(self.eval.log_size());
|
||||
|
||||
let component_polys = trace.polys.sub_tree(&self.trace_locations);
|
||||
let component_evals = trace.evals.sub_tree(&self.trace_locations);
|
||||
|
||||
// Extend trace if necessary.
|
||||
// TODO(spapini): Don't extend when eval_size < committed_size. Instead, pick a good
|
||||
// subdomain.
|
||||
let need_to_extend = component_evals
|
||||
.iter()
|
||||
.flatten()
|
||||
.any(|c| c.domain != eval_domain);
|
||||
let trace: TreeVec<
|
||||
Vec<Cow<'_, CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>>,
|
||||
> = if need_to_extend {
|
||||
let _span = span!(Level::INFO, "Extension").entered();
|
||||
let twiddles = SimdBackend::precompute_twiddles(eval_domain.half_coset);
|
||||
component_polys
|
||||
.as_cols_ref()
|
||||
.map_cols(|col| Cow::Owned(col.evaluate_with_twiddles(eval_domain, &twiddles)))
|
||||
} else {
|
||||
component_evals.clone().map_cols(|c| Cow::Borrowed(*c))
|
||||
};
|
||||
|
||||
// Denom inverses.
|
||||
let log_expand = eval_domain.log_size() - trace_domain.log_size();
|
||||
let mut denom_inv = (0..1 << log_expand)
|
||||
.map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse())
|
||||
.collect_vec();
|
||||
utils::bit_reverse(&mut denom_inv);
|
||||
|
||||
// Accumulator.
|
||||
let [mut accum] =
|
||||
evaluation_accumulator.columns([(eval_domain.log_size(), self.n_constraints())]);
|
||||
accum.random_coeff_powers.reverse();
|
||||
|
||||
let _span = span!(Level::INFO, "Constraint pointwise eval").entered();
|
||||
let col = unsafe { VeryPackedSecureColumnByCoords::transform_under_mut(accum.col) };
|
||||
|
||||
for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)) {
|
||||
let trace_cols = trace.as_cols_ref().map_cols(|c| c.as_ref());
|
||||
|
||||
// Evaluate constraints at the row.
|
||||
let eval = SimdDomainEvaluator::new(
|
||||
&trace_cols,
|
||||
vec_row,
|
||||
&accum.random_coeff_powers,
|
||||
trace_domain.log_size(),
|
||||
eval_domain.log_size(),
|
||||
);
|
||||
let row_res = self.eval.evaluate(eval).row_res;
|
||||
|
||||
// Finalize row.
|
||||
unsafe {
|
||||
let denom_inv = VeryPackedBaseField::broadcast(
|
||||
denom_inv[vec_row
|
||||
>> (trace_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)],
|
||||
);
|
||||
col.set_packed(vec_row, col.packed_at(vec_row) + row_res * denom_inv)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: FrameworkEval> Deref for FrameworkComponent<E> {
|
||||
type Target = E;
|
||||
|
||||
fn deref(&self) -> &E {
|
||||
&self.eval
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,37 @@
|
||||
use num_traits::One;
|
||||
|
||||
use crate::core::backend::{Backend, Col, Column};
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
|
||||
use crate::core::poly::BitReversedOrder;
|
||||
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};
|
||||
|
||||
/// Generates a column with a single one at the first position, and zeros elsewhere.
|
||||
pub fn gen_is_first<B: Backend>(log_size: u32) -> CircleEvaluation<B, BaseField, BitReversedOrder> {
|
||||
let mut col = Col::<B, BaseField>::zeros(1 << log_size);
|
||||
col.set(0, BaseField::one());
|
||||
CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col)
|
||||
}
|
||||
|
||||
/// Generates a column with `1` at every `2^log_step`-th position and `0` elsewhere, shifted by `offset`.
// TODO(andrew): Consider optimizing. This is a quotient of two coset_vanishing polynomials (use a
// succinct representation for the verifier).
|
||||
pub fn gen_is_step_with_offset<B: Backend>(
|
||||
log_size: u32,
|
||||
log_step: u32,
|
||||
offset: usize,
|
||||
) -> CircleEvaluation<B, BaseField, BitReversedOrder> {
|
||||
let mut col = Col::<B, BaseField>::zeros(1 << log_size);
|
||||
|
||||
let size = 1 << log_size;
|
||||
let step = 1 << log_step;
|
||||
let step_offset = offset % step;
|
||||
|
||||
for i in (step_offset..size).step_by(step) {
|
||||
let circle_domain_index = coset_index_to_circle_domain_index(i, log_size);
|
||||
let circle_domain_index_bit_rev = bit_reverse_index(circle_domain_index, log_size);
|
||||
col.set(circle_domain_index_bit_rev, BaseField::one());
|
||||
}
|
||||
|
||||
CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col)
|
||||
}
|
||||
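A short usage sketch of the two generators above, assuming the SIMD backend (the log_size, log_step and offset values are arbitrary):

// Sketch only: a 32-row column that is 1 at row 0, and one that is 1 at every
// 8th row starting from row 0 (coset order), both stored in bit-reversed order.
let is_first = gen_is_first::<SimdBackend>(5);
let is_step = gen_is_step_with_offset::<SimdBackend>(5, 3, 0);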
48
Stwo_wrapper/crates/prover/src/constraint_framework/info.rs
Normal file
@ -0,0 +1,48 @@
|
||||
use std::ops::Mul;
|
||||
|
||||
use num_traits::One;
|
||||
|
||||
use super::EvalAtRow;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::pcs::TreeVec;
|
||||
|
||||
/// Collects information about the constraints.
|
||||
/// This includes mask offsets and columns at each interaction, and the number of constraints.
|
||||
#[derive(Default)]
|
||||
pub struct InfoEvaluator {
|
||||
pub mask_offsets: TreeVec<Vec<Vec<isize>>>,
|
||||
pub n_constraints: usize,
|
||||
}
|
||||
impl InfoEvaluator {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
impl EvalAtRow for InfoEvaluator {
|
||||
type F = BaseField;
|
||||
type EF = SecureField;
|
||||
fn next_interaction_mask<const N: usize>(
|
||||
&mut self,
|
||||
interaction: usize,
|
||||
offsets: [isize; N],
|
||||
) -> [Self::F; N] {
|
||||
// Check if a mask was requested from a new interaction.
|
||||
if self.mask_offsets.len() <= interaction {
|
||||
// Extend `mask_offsets` so that `interaction` is the last index.
|
||||
self.mask_offsets.resize(interaction + 1, vec![]);
|
||||
}
|
||||
self.mask_offsets[interaction].push(offsets.into_iter().collect());
|
||||
[BaseField::one(); N]
|
||||
}
|
||||
fn add_constraint<G>(&mut self, _constraint: G)
|
||||
where
|
||||
Self::EF: Mul<G, Output = Self::EF>,
|
||||
{
|
||||
self.n_constraints += 1;
|
||||
}
|
||||
|
||||
fn combine_ef(_values: [Self::F; 4]) -> Self::EF {
|
||||
SecureField::one()
|
||||
}
|
||||
}
|
||||
315
Stwo_wrapper/crates/prover/src/constraint_framework/logup.rs
Normal file
@ -0,0 +1,315 @@
|
||||
use std::ops::{Mul, Sub};
|
||||
|
||||
use itertools::Itertools;
|
||||
use num_traits::{One, Zero};
|
||||
|
||||
use super::EvalAtRow;
|
||||
use crate::core::backend::simd::column::SecureColumn;
|
||||
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
|
||||
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum;
|
||||
use crate::core::backend::simd::qm31::PackedSecureField;
|
||||
use crate::core::backend::simd::SimdBackend;
|
||||
use crate::core::backend::Column;
|
||||
use crate::core::channel::Channel;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE};
|
||||
use crate::core::fields::FieldExpOps;
|
||||
use crate::core::lookups::utils::Fraction;
|
||||
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
|
||||
use crate::core::poly::BitReversedOrder;
|
||||
use crate::core::ColumnVec;
|
||||
|
||||
/// Evaluates constraints for batched logups.
|
||||
/// These constraints enforce that the sum of multiplicity_i / (z + sum_j alpha^j * x_j) equals claimed_sum.
|
||||
/// BATCH_SIZE is the number of fractions to batch together. The degree of the resulting constraints
|
||||
/// will be BATCH_SIZE + 1.
|
||||
pub struct LogupAtRow<const BATCH_SIZE: usize, E: EvalAtRow> {
|
||||
/// The index of the interaction used for the cumulative sum columns.
|
||||
pub interaction: usize,
|
||||
/// Queue of fractions waiting to be batched together.
|
||||
pub queue: [(E::EF, E::EF); BATCH_SIZE],
|
||||
/// Number of fractions in the queue.
|
||||
pub queue_size: usize,
|
||||
/// A constant to subtract from each row, to make the total sum of the last column zero.
|
||||
/// In other words, claimed_sum / 2^log_size.
|
||||
/// This is used to make the constraint uniform.
|
||||
pub cumsum_shift: SecureField,
|
||||
/// The evaluation of the last cumulative sum column.
|
||||
pub prev_col_cumsum: E::EF,
|
||||
is_finalized: bool,
|
||||
}
|
||||
impl<const BATCH_SIZE: usize, E: EvalAtRow> LogupAtRow<BATCH_SIZE, E> {
|
||||
pub fn new(interaction: usize, claimed_sum: SecureField, log_size: u32) -> Self {
|
||||
Self {
|
||||
interaction,
|
||||
queue: [(E::EF::zero(), E::EF::zero()); BATCH_SIZE],
|
||||
queue_size: 0,
|
||||
cumsum_shift: claimed_sum / BaseField::from_u32_unchecked(1 << log_size),
|
||||
prev_col_cumsum: E::EF::zero(),
|
||||
is_finalized: false,
|
||||
}
|
||||
}
|
||||
pub fn push_lookup<const N: usize>(
|
||||
&mut self,
|
||||
eval: &mut E,
|
||||
numerator: E::EF,
|
||||
values: &[E::F],
|
||||
lookup_elements: &LookupElements<N>,
|
||||
) {
|
||||
let shifted_value = lookup_elements.combine(values);
|
||||
self.push_frac(eval, numerator, shifted_value);
|
||||
}
|
||||
|
||||
pub fn push_frac(&mut self, eval: &mut E, numerator: E::EF, denominator: E::EF) {
|
||||
if self.queue_size < BATCH_SIZE {
|
||||
self.queue[self.queue_size] = (numerator, denominator);
|
||||
self.queue_size += 1;
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute sum_i pi/qi over batch, as a fraction, num/denom.
|
||||
let (num, denom) = self.fold_queue();
|
||||
|
||||
self.queue[0] = (numerator, denominator);
|
||||
self.queue_size = 1;
|
||||
|
||||
// Add a constraint that num / denom = diff.
|
||||
let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0];
|
||||
let diff = cur_cumsum - self.prev_col_cumsum;
|
||||
self.prev_col_cumsum = cur_cumsum;
|
||||
eval.add_constraint(diff * denom - num);
|
||||
}
|
||||
|
||||
pub fn add_frac(&mut self, eval: &mut E, fraction: Fraction<E::EF, E::EF>) {
|
||||
// Add a constraint that num / denom = diff.
|
||||
let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0];
|
||||
let diff = cur_cumsum - self.prev_col_cumsum;
|
||||
self.prev_col_cumsum = cur_cumsum;
|
||||
eval.add_constraint(diff * fraction.denominator - fraction.numerator);
|
||||
}
|
||||
|
||||
pub fn finalize(mut self, eval: &mut E) {
|
||||
assert!(!self.is_finalized, "LogupAtRow was already finalized");
|
||||
let (num, denom) = self.fold_queue();
|
||||
|
||||
let [cur_cumsum, prev_row_cumsum] =
|
||||
eval.next_extension_interaction_mask(self.interaction, [0, -1]);
|
||||
|
||||
let diff = cur_cumsum - prev_row_cumsum - self.prev_col_cumsum;
|
||||
// Instead of checking diff = num / denom, check diff = num / denom - cumsum_shift.
|
||||
// This makes (num / denom - cumsum_shift) have sum zero, which makes the constraint
|
||||
// uniform - it applies to all rows.
|
||||
let fixed_diff = diff + self.cumsum_shift;
|
||||
|
||||
eval.add_constraint(fixed_diff * denom - num);
|
||||
|
||||
self.is_finalized = true;
|
||||
}
|
||||
|
||||
fn fold_queue(&self) -> (E::EF, E::EF) {
|
||||
self.queue[0..self.queue_size]
|
||||
.iter()
|
||||
.copied()
|
||||
.fold((E::EF::zero(), E::EF::one()), |(p0, q0), (pi, qi)| {
|
||||
(p0 * qi + pi * q0, qi * q0)
|
||||
})
|
||||
}
|
||||
}
|
||||
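A sketch of how `LogupAtRow` might be used inside a component's `evaluate`, assuming `lookup_elements`, `claimed_sum` and an interaction index of 1 are provided elsewhere (all hypothetical here):

// Sketch only: one lookup of a single trace value with multiplicity 1.
let mut logup = LogupAtRow::<2, E>::new(1, claimed_sum, self.log_size());
let value = eval.next_trace_mask();
logup.push_lookup(&mut eval, E::EF::one(), &[value], &lookup_elements);
logup.finalize(&mut eval);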
|
||||
/// Ensures that the LogupAtRow is finalized.
|
||||
/// LogupAtRow should be finalized exactly once.
|
||||
impl<const BATCH_SIZE: usize, E: EvalAtRow> Drop for LogupAtRow<BATCH_SIZE, E> {
|
||||
fn drop(&mut self) {
|
||||
assert!(self.is_finalized, "LogupAtRow was not finalized");
|
||||
}
|
||||
}
|
||||
|
||||
/// Interaction elements for the logup protocol.
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct LookupElements<const N: usize> {
|
||||
pub z: SecureField,
|
||||
pub alpha: SecureField,
|
||||
alpha_powers: [SecureField; N],
|
||||
}
|
||||
impl<const N: usize> LookupElements<N> {
|
||||
pub fn draw(channel: &mut impl Channel) -> Self {
|
||||
let [z, alpha] = channel.draw_felts(2).try_into().unwrap();
|
||||
let mut cur = SecureField::one();
|
||||
let alpha_powers = std::array::from_fn(|_| {
|
||||
let res = cur;
|
||||
cur *= alpha;
|
||||
res
|
||||
});
|
||||
Self {
|
||||
z,
|
||||
alpha,
|
||||
alpha_powers,
|
||||
}
|
||||
}
|
||||
pub fn combine<F: Copy, EF>(&self, values: &[F]) -> EF
|
||||
where
|
||||
EF: Copy + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<EF, Output = EF>,
|
||||
{
|
||||
EF::from(values[0])
|
||||
+ values[1..]
|
||||
.iter()
|
||||
.zip(self.alpha_powers.iter())
|
||||
.fold(EF::zero(), |acc, (&value, &power)| {
|
||||
acc + EF::from(power) * value
|
||||
})
|
||||
- EF::from(self.z)
|
||||
}
|
||||
// TODO(spapini): Try to remove this.
|
||||
pub fn dummy() -> Self {
|
||||
Self {
|
||||
z: SecureField::one(),
|
||||
alpha: SecureField::one(),
|
||||
alpha_powers: [SecureField::one(); N],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SIMD backend generator for logup interaction trace.
|
||||
pub struct LogupTraceGenerator {
|
||||
log_size: u32,
|
||||
/// Current allocated interaction columns.
|
||||
trace: Vec<SecureColumnByCoords<SimdBackend>>,
|
||||
/// Denominator expressions (z + sum_i alpha^i * x_i) being generated for the current lookup.
|
||||
denom: SecureColumn,
|
||||
/// Preallocated buffer for the inverses of the denominators.
|
||||
denom_inv: SecureColumn,
|
||||
}
|
||||
impl LogupTraceGenerator {
|
||||
pub fn new(log_size: u32) -> Self {
|
||||
let trace = vec![];
|
||||
let denom = SecureColumn::zeros(1 << log_size);
|
||||
let denom_inv = SecureColumn::zeros(1 << log_size);
|
||||
Self {
|
||||
log_size,
|
||||
trace,
|
||||
denom,
|
||||
denom_inv,
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate a new lookup column.
|
||||
pub fn new_col(&mut self) -> LogupColGenerator<'_> {
|
||||
let log_size = self.log_size;
|
||||
LogupColGenerator {
|
||||
gen: self,
|
||||
numerator: SecureColumnByCoords::<SimdBackend>::zeros(1 << log_size),
|
||||
}
|
||||
}
|
||||
|
||||
/// Finalize the trace. Returns the trace and the claimed sum of the last column.
|
||||
pub fn finalize(
|
||||
mut self,
|
||||
) -> (
|
||||
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
|
||||
SecureField,
|
||||
) {
|
||||
// Compute claimed sum.
|
||||
let mut last_col_coords = self.trace.pop().unwrap().columns;
|
||||
let packed_sums: [PackedBaseField; SECURE_EXTENSION_DEGREE] = last_col_coords
|
||||
.each_ref()
|
||||
.map(|c| c.data.iter().copied().sum());
|
||||
let base_sums = packed_sums.map(|s| s.pointwise_sum());
|
||||
let claimed_sum = SecureField::from_m31_array(base_sums);
|
||||
|
||||
// Shift the last column to make the sum zero.
|
||||
let cumsum_shift = claimed_sum / BaseField::from_u32_unchecked(1 << self.log_size);
|
||||
last_col_coords.iter_mut().enumerate().for_each(|(i, c)| {
|
||||
c.data
|
||||
.iter_mut()
|
||||
.for_each(|x| *x -= PackedBaseField::broadcast(cumsum_shift.to_m31_array()[i]))
|
||||
});
|
||||
|
||||
// Prefix sum the last column.
|
||||
let coord_prefix_sum = last_col_coords.map(inclusive_prefix_sum);
|
||||
self.trace.push(SecureColumnByCoords {
|
||||
columns: coord_prefix_sum,
|
||||
});
|
||||
|
||||
let trace = self
|
||||
.trace
|
||||
.into_iter()
|
||||
.flat_map(|eval| {
|
||||
eval.columns.map(|c| {
|
||||
CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new(
|
||||
CanonicCoset::new(self.log_size).circle_domain(),
|
||||
c,
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect_vec();
|
||||
(trace, claimed_sum)
|
||||
}
|
||||
}
|
||||
|
||||
/// Trace generator for a single lookup column.
|
||||
pub struct LogupColGenerator<'a> {
|
||||
gen: &'a mut LogupTraceGenerator,
|
||||
/// Numerator expressions (i.e. multiplicities) being generated for the current lookup.
|
||||
numerator: SecureColumnByCoords<SimdBackend>,
|
||||
}
|
||||
impl<'a> LogupColGenerator<'a> {
|
||||
/// Write a fraction to the column at a row.
|
||||
pub fn write_frac(
|
||||
&mut self,
|
||||
vec_row: usize,
|
||||
numerator: PackedSecureField,
|
||||
denom: PackedSecureField,
|
||||
) {
|
||||
debug_assert!(
denom.to_array().iter().all(|x| *x != SecureField::zero()),
"denom at vec_row {} is zero: {:?}",
vec_row,
denom
);
|
||||
unsafe {
|
||||
self.numerator.set_packed(vec_row, numerator);
|
||||
*self.gen.denom.data.get_unchecked_mut(vec_row) = denom;
|
||||
}
|
||||
}
|
||||
|
||||
/// Finalizes generating the column.
|
||||
pub fn finalize_col(mut self) {
|
||||
FieldExpOps::batch_inverse(&self.gen.denom.data, &mut self.gen.denom_inv.data);
|
||||
|
||||
for vec_row in 0..(1 << (self.gen.log_size - LOG_N_LANES)) {
|
||||
unsafe {
|
||||
let value = self.numerator.packed_at(vec_row)
|
||||
* *self.gen.denom_inv.data.get_unchecked(vec_row);
|
||||
let prev_value = self
|
||||
.gen
|
||||
.trace
|
||||
.last()
|
||||
.map(|col| col.packed_at(vec_row))
|
||||
.unwrap_or_else(PackedSecureField::zero);
|
||||
self.numerator.set_packed(vec_row, value + prev_value)
|
||||
};
|
||||
}
|
||||
|
||||
self.gen.trace.push(self.numerator)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use num_traits::One;
|
||||
|
||||
use super::LogupAtRow;
|
||||
use crate::constraint_framework::InfoEvaluator;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_logup_not_finalized_panic() {
|
||||
let mut logup = LogupAtRow::<2, InfoEvaluator>::new(1, SecureField::one(), 7);
|
||||
logup.push_frac(
|
||||
&mut InfoEvaluator::default(),
|
||||
SecureField::one(),
|
||||
SecureField::one(),
|
||||
);
|
||||
}
|
||||
}
|
||||
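A sketch of generating the matching interaction trace with `LogupTraceGenerator`, assuming a committed base column `values` and the same `lookup_elements` (hypothetical names):

// Sketch only: one logup column with numerator 1 and denominator
// z + alpha^0 * values[row], written one packed row at a time.
let mut logup_gen = LogupTraceGenerator::new(log_size);
let mut col_gen = logup_gen.new_col();
for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
    let denom: PackedSecureField = lookup_elements.combine(&[values.data[vec_row]]);
    col_gen.write_frac(vec_row, PackedSecureField::one(), denom);
}
col_gen.finalize_col();
let (interaction_trace, claimed_sum) = logup_gen.finalize();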
97
Stwo_wrapper/crates/prover/src/constraint_framework/mod.rs
Normal file
@ -0,0 +1,97 @@
|
||||
//! This module contains helpers to express and use constraints for components.
|
||||
mod assert;
|
||||
mod component;
|
||||
pub mod constant_columns;
|
||||
mod info;
|
||||
pub mod logup;
|
||||
mod point;
|
||||
mod simd_domain;
|
||||
|
||||
use std::array;
|
||||
use std::fmt::Debug;
|
||||
use std::ops::{Add, AddAssign, Mul, Neg, Sub};
|
||||
|
||||
pub use assert::{assert_constraints, AssertEvaluator};
|
||||
pub use component::{FrameworkComponent, FrameworkEval, TraceLocationAllocator};
|
||||
pub use info::InfoEvaluator;
|
||||
use num_traits::{One, Zero};
|
||||
pub use point::PointEvaluator;
|
||||
pub use simd_domain::SimdDomainEvaluator;
|
||||
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
|
||||
use crate::core::fields::FieldExpOps;
|
||||
|
||||
/// A trait for evaluating expressions at some point or row.
|
||||
pub trait EvalAtRow {
|
||||
// TODO(spapini): Use a better trait for these, like 'Algebra' or something.
|
||||
/// The field type holding values of columns for the component. These are the inputs to the
|
||||
/// constraints. It might be a packed [BaseField] type, or even [SecureField], when evaluating
|
||||
/// the columns out of domain.
|
||||
type F: FieldExpOps
|
||||
+ Copy
|
||||
+ Debug
|
||||
+ Zero
|
||||
+ Neg<Output = Self::F>
|
||||
+ AddAssign
|
||||
+ AddAssign<BaseField>
|
||||
+ Add<Self::F, Output = Self::F>
|
||||
+ Sub<Self::F, Output = Self::F>
|
||||
+ Mul<BaseField, Output = Self::F>
|
||||
+ Add<SecureField, Output = Self::EF>
|
||||
+ Mul<SecureField, Output = Self::EF>
|
||||
+ Neg<Output = Self::F>
|
||||
+ From<BaseField>;
|
||||
|
||||
/// A field type representing the closure of `F` under multiplication by [SecureField]. Constraints
|
||||
/// usually get multiplied by [SecureField] values for security.
|
||||
type EF: One
|
||||
+ Copy
|
||||
+ Debug
|
||||
+ Zero
|
||||
+ From<Self::F>
|
||||
+ Neg<Output = Self::EF>
|
||||
+ AddAssign
|
||||
+ Add<SecureField, Output = Self::EF>
|
||||
+ Sub<SecureField, Output = Self::EF>
|
||||
+ Mul<SecureField, Output = Self::EF>
|
||||
+ Add<Self::F, Output = Self::EF>
|
||||
+ Mul<Self::F, Output = Self::EF>
|
||||
+ Sub<Self::EF, Output = Self::EF>
|
||||
+ Mul<Self::EF, Output = Self::EF>
|
||||
+ From<SecureField>
|
||||
+ From<Self::F>;
|
||||
|
||||
/// Returns the next mask value for the first interaction at offset 0.
|
||||
fn next_trace_mask(&mut self) -> Self::F {
|
||||
let [mask_item] = self.next_interaction_mask(0, [0]);
|
||||
mask_item
|
||||
}
|
||||
|
||||
/// Returns the mask values of the given offsets for the next column in the interaction.
|
||||
fn next_interaction_mask<const N: usize>(
|
||||
&mut self,
|
||||
interaction: usize,
|
||||
offsets: [isize; N],
|
||||
) -> [Self::F; N];
|
||||
|
||||
/// Returns the extension mask values of the given offsets for the next extension degree many
|
||||
/// columns in the interaction.
|
||||
fn next_extension_interaction_mask<const N: usize>(
|
||||
&mut self,
|
||||
interaction: usize,
|
||||
offsets: [isize; N],
|
||||
) -> [Self::EF; N] {
|
||||
let res_col_major = array::from_fn(|_| self.next_interaction_mask(interaction, offsets));
|
||||
array::from_fn(|i| Self::combine_ef(res_col_major.map(|c| c[i])))
|
||||
}
|
||||
|
||||
/// Adds a constraint to the component.
|
||||
fn add_constraint<G>(&mut self, constraint: G)
|
||||
where
|
||||
Self::EF: Mul<G, Output = Self::EF>;
|
||||
|
||||
/// Combines 4 base field values into a single extension field value.
|
||||
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF;
|
||||
}
|
||||
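A brief sketch of the mask API from inside some `evaluate` implementation (the interaction index and offsets are arbitrary):

// Sketch only: read the current and previous row of the next base column in
// interaction 1, then the current row of the next secure (extension) column.
let [cur, prev] = eval.next_interaction_mask(1, [0, -1]);
let [cur_secure] = eval.next_extension_interaction_mask(1, [0]);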
57
Stwo_wrapper/crates/prover/src/constraint_framework/point.rs
Normal file
@ -0,0 +1,57 @@
|
||||
use std::ops::Mul;
|
||||
|
||||
use super::EvalAtRow;
|
||||
use crate::core::air::accumulation::PointEvaluationAccumulator;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
|
||||
use crate::core::pcs::TreeVec;
|
||||
use crate::core::ColumnVec;
|
||||
|
||||
/// Evaluates expressions at an out-of-domain point.
|
||||
pub struct PointEvaluator<'a> {
|
||||
pub mask: TreeVec<ColumnVec<&'a Vec<SecureField>>>,
|
||||
pub evaluation_accumulator: &'a mut PointEvaluationAccumulator,
|
||||
pub col_index: Vec<usize>,
|
||||
pub denom_inverse: SecureField,
|
||||
}
|
||||
impl<'a> PointEvaluator<'a> {
|
||||
pub fn new(
|
||||
mask: TreeVec<ColumnVec<&'a Vec<SecureField>>>,
|
||||
evaluation_accumulator: &'a mut PointEvaluationAccumulator,
|
||||
denom_inverse: SecureField,
|
||||
) -> Self {
|
||||
let col_index = vec![0; mask.len()];
|
||||
Self {
|
||||
mask,
|
||||
evaluation_accumulator,
|
||||
col_index,
|
||||
denom_inverse,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<'a> EvalAtRow for PointEvaluator<'a> {
|
||||
type F = SecureField;
|
||||
type EF = SecureField;
|
||||
|
||||
fn next_interaction_mask<const N: usize>(
|
||||
&mut self,
|
||||
interaction: usize,
|
||||
_offsets: [isize; N],
|
||||
) -> [Self::F; N] {
|
||||
let col_index = self.col_index[interaction];
|
||||
self.col_index[interaction] += 1;
|
||||
let mask = self.mask[interaction][col_index].clone();
|
||||
assert_eq!(mask.len(), N);
|
||||
mask.try_into().unwrap()
|
||||
}
|
||||
fn add_constraint<G>(&mut self, constraint: G)
|
||||
where
|
||||
Self::EF: Mul<G, Output = Self::EF>,
|
||||
{
|
||||
self.evaluation_accumulator
|
||||
.accumulate(self.denom_inverse * constraint);
|
||||
}
|
||||
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
|
||||
SecureField::from_partial_evals(values)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,106 @@
|
||||
use std::ops::Mul;
|
||||
|
||||
use num_traits::Zero;
|
||||
|
||||
use super::EvalAtRow;
|
||||
use crate::core::backend::simd::column::VeryPackedBaseColumn;
|
||||
use crate::core::backend::simd::m31::LOG_N_LANES;
|
||||
use crate::core::backend::simd::very_packed_m31::{
|
||||
VeryPackedBaseField, VeryPackedSecureField, LOG_N_VERY_PACKED_ELEMS,
|
||||
};
|
||||
use crate::core::backend::simd::SimdBackend;
|
||||
use crate::core::backend::Column;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
|
||||
use crate::core::pcs::TreeVec;
|
||||
use crate::core::poly::circle::CircleEvaluation;
|
||||
use crate::core::poly::BitReversedOrder;
|
||||
use crate::core::utils::offset_bit_reversed_circle_domain_index;
|
||||
|
||||
/// Evaluates constraints at evaluation domain points.
|
||||
pub struct SimdDomainEvaluator<'a> {
|
||||
pub trace_eval:
|
||||
&'a TreeVec<Vec<&'a CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>>,
|
||||
pub column_index_per_interaction: Vec<usize>,
|
||||
/// The row index of the simd-vector row to evaluate the constraints at.
|
||||
pub vec_row: usize,
|
||||
pub random_coeff_powers: &'a [SecureField],
|
||||
pub row_res: VeryPackedSecureField,
|
||||
pub constraint_index: usize,
|
||||
pub domain_log_size: u32,
|
||||
pub eval_domain_log_size: u32,
|
||||
}
|
||||
impl<'a> SimdDomainEvaluator<'a> {
|
||||
pub fn new(
|
||||
trace_eval: &'a TreeVec<Vec<&CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>>,
|
||||
vec_row: usize,
|
||||
random_coeff_powers: &'a [SecureField],
|
||||
domain_log_size: u32,
|
||||
eval_log_size: u32,
|
||||
) -> Self {
|
||||
Self {
|
||||
trace_eval,
|
||||
column_index_per_interaction: vec![0; trace_eval.len()],
|
||||
vec_row,
|
||||
random_coeff_powers,
|
||||
row_res: VeryPackedSecureField::zero(),
|
||||
constraint_index: 0,
|
||||
domain_log_size,
|
||||
eval_domain_log_size: eval_log_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<'a> EvalAtRow for SimdDomainEvaluator<'a> {
|
||||
type F = VeryPackedBaseField;
|
||||
type EF = VeryPackedSecureField;
|
||||
|
||||
// TODO(spapini): Remove all boundary checks.
|
||||
fn next_interaction_mask<const N: usize>(
|
||||
&mut self,
|
||||
interaction: usize,
|
||||
offsets: [isize; N],
|
||||
) -> [Self::F; N] {
|
||||
let col_index = self.column_index_per_interaction[interaction];
|
||||
self.column_index_per_interaction[interaction] += 1;
|
||||
offsets.map(|off| {
|
||||
// If the offset is 0, we can just return the value directly from this row.
|
||||
if off == 0 {
|
||||
unsafe {
|
||||
let col = &self
|
||||
.trace_eval
|
||||
.get_unchecked(interaction)
|
||||
.get_unchecked(col_index)
|
||||
.values;
|
||||
let very_packed_col = VeryPackedBaseColumn::transform_under_ref(col);
|
||||
return *very_packed_col.data.get_unchecked(self.vec_row);
|
||||
};
|
||||
}
|
||||
// Otherwise, we need to look up the value at the offset.
|
||||
// Since the domain is bit-reversed circle domain ordered, we need to look up the value
|
||||
// at the bit-reversed natural order index at an offset.
|
||||
VeryPackedBaseField::from_array(std::array::from_fn(|i| {
|
||||
let row_index = offset_bit_reversed_circle_domain_index(
|
||||
(self.vec_row << (LOG_N_LANES + LOG_N_VERY_PACKED_ELEMS)) + i,
|
||||
self.domain_log_size,
|
||||
self.eval_domain_log_size,
|
||||
off,
|
||||
);
|
||||
self.trace_eval[interaction][col_index].at(row_index)
|
||||
}))
|
||||
})
|
||||
}
|
||||
fn add_constraint<G>(&mut self, constraint: G)
|
||||
where
|
||||
Self::EF: Mul<G, Output = Self::EF>,
|
||||
{
|
||||
self.row_res +=
|
||||
VeryPackedSecureField::broadcast(self.random_coeff_powers[self.constraint_index])
|
||||
* constraint;
|
||||
self.constraint_index += 1;
|
||||
}
|
||||
|
||||
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
|
||||
VeryPackedSecureField::from_very_packed_m31s(values)
|
||||
}
|
||||
}
|
||||
297
Stwo_wrapper/crates/prover/src/core/air/accumulation.rs
Normal file
@ -0,0 +1,297 @@
|
||||
//! Accumulators for a random linear combination of circle polynomials.
|
||||
//! Given N polynomials, u_0(P), ... u_{N-1}(P), and a random alpha, the combined polynomial is
|
||||
//! defined as
|
||||
//! f(P) = sum_i alpha^{N-1-i} u_i(P).
|
||||
|
||||
use itertools::Itertools;
|
||||
use tracing::{span, Level};
|
||||
|
||||
use crate::core::backend::{Backend, Col, Column, CpuBackend};
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::secure_column::SecureColumnByCoords;
|
||||
use crate::core::fields::FieldOps;
|
||||
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, SecureCirclePoly};
|
||||
use crate::core::poly::BitReversedOrder;
|
||||
use crate::core::utils::generate_secure_powers;
|
||||
|
||||
/// Accumulates N evaluations of u_i(P0) at a single point.
|
||||
/// Computes f(P0), the combined polynomial at that point.
|
||||
/// For N accumulated evaluations, the i'th evaluation is multiplied by alpha^(N-1-i).
|
||||
pub struct PointEvaluationAccumulator {
|
||||
random_coeff: SecureField,
|
||||
accumulation: SecureField,
|
||||
}
|
||||
|
||||
impl PointEvaluationAccumulator {
|
||||
/// Creates a new accumulator.
|
||||
/// `random_coeff` should be a secure random field element, drawn from the channel.
|
||||
pub fn new(random_coeff: SecureField) -> Self {
|
||||
Self {
|
||||
random_coeff,
|
||||
accumulation: SecureField::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Accumulates u_i(P0), a polynomial evaluation at P0, in reverse order.
|
||||
pub fn accumulate(&mut self, evaluation: SecureField) {
|
||||
self.accumulation = self.accumulation * self.random_coeff + evaluation;
|
||||
}
|
||||
|
||||
pub fn finalize(self) -> SecureField {
|
||||
self.accumulation
|
||||
}
|
||||
}
|
||||
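A small worked sketch of the Horner-style accumulation above (e0, e1, e2 and alpha are arbitrary `SecureField` values):

// Sketch only: after three calls, the accumulator holds
// alpha^2 * e0 + alpha * e1 + e2, i.e. f(P0) for N = 3.
let mut acc = PointEvaluationAccumulator::new(alpha);
acc.accumulate(e0);
acc.accumulate(e1);
acc.accumulate(e2);
let f_p0 = acc.finalize();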
|
||||
// TODO(ShaharS), rename terminology to constraints instead of columns.
|
||||
/// Accumulates evaluations of u_i(P), each at an evaluation domain of the size of that polynomial.
|
||||
/// Computes the coefficients of f(P).
|
||||
pub struct DomainEvaluationAccumulator<B: Backend> {
|
||||
random_coeff_powers: Vec<SecureField>,
|
||||
/// Accumulated evaluations for each log_size.
|
||||
/// Each `sub_accumulation` holds the sum over all columns i of that log_size, of
|
||||
/// `evaluation_i * alpha^(N - 1 - i)`
|
||||
/// where `N` is the total number of evaluations.
|
||||
sub_accumulations: Vec<Option<SecureColumnByCoords<B>>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> DomainEvaluationAccumulator<B> {
|
||||
/// Creates a new accumulator.
|
||||
/// `random_coeff` should be a secure random field element, drawn from the channel.
|
||||
/// `max_log_size` is the maximum log_size of the accumulated evaluations.
|
||||
pub fn new(random_coeff: SecureField, max_log_size: u32, total_columns: usize) -> Self {
|
||||
let max_log_size = max_log_size as usize;
|
||||
Self {
|
||||
random_coeff_powers: generate_secure_powers(random_coeff, total_columns),
|
||||
sub_accumulations: (0..(max_log_size + 1)).map(|_| None).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets accumulators for some sizes.
|
||||
/// `n_cols_per_size` is an array of pairs (log_size, n_cols).
|
||||
/// For each entry, a [ColumnAccumulator] is returned, expecting to accumulate `n_cols`
|
||||
/// evaluations of size `log_size`.
|
||||
/// The array size, `N`, is the number of different sizes.
|
||||
pub fn columns<const N: usize>(
|
||||
&mut self,
|
||||
n_cols_per_size: [(u32, usize); N],
|
||||
) -> [ColumnAccumulator<'_, B>; N] {
|
||||
self.sub_accumulations
|
||||
.get_many_mut(n_cols_per_size.map(|(log_size, _)| log_size as usize))
|
||||
.unwrap_or_else(|e| panic!("invalid log_sizes: {}", e))
|
||||
.into_iter()
|
||||
.zip(n_cols_per_size)
|
||||
.map(|(col, (log_size, n_cols))| {
|
||||
let random_coeffs = self
|
||||
.random_coeff_powers
|
||||
.split_off(self.random_coeff_powers.len() - n_cols);
|
||||
ColumnAccumulator {
|
||||
random_coeff_powers: random_coeffs,
|
||||
col: col.get_or_insert_with(|| SecureColumnByCoords::zeros(1 << log_size)),
|
||||
}
|
||||
})
|
||||
.collect_vec()
|
||||
.try_into()
|
||||
.unwrap_or_else(|_| unreachable!())
|
||||
}
|
||||
|
||||
/// Returns the log size of the resulting polynomial.
|
||||
pub fn log_size(&self) -> u32 {
|
||||
(self.sub_accumulations.len() - 1) as u32
|
||||
}
|
||||
}
|
||||
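A sketch of requesting column accumulators from `columns`, assuming three constraints evaluated at log_size 6 and one at log_size 4 (arbitrary numbers):

// Sketch only: the total number of requested columns (3 + 1) must match the
// `total_columns` the accumulator was created with.
let [mut acc_6, mut acc_4] = accumulator.columns([(6, 3), (4, 1)]);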
|
||||
pub trait AccumulationOps: FieldOps<BaseField> + Sized {
|
||||
/// Accumulates other into column:
|
||||
/// column = column + other.
|
||||
fn accumulate(column: &mut SecureColumnByCoords<Self>, other: &SecureColumnByCoords<Self>);
|
||||
}
|
||||
|
||||
impl<B: Backend> DomainEvaluationAccumulator<B> {
|
||||
/// Computes f(P) as coefficients.
|
||||
pub fn finalize(self) -> SecureCirclePoly<B> {
|
||||
assert_eq!(
|
||||
self.random_coeff_powers.len(),
|
||||
0,
|
||||
"not all random coefficients were used"
|
||||
);
|
||||
let log_size = self.log_size();
|
||||
let _span = span!(Level::INFO, "Constraints interpolation").entered();
|
||||
let mut cur_poly: Option<SecureCirclePoly<B>> = None;
|
||||
let twiddles = B::precompute_twiddles(
|
||||
CanonicCoset::new(self.log_size())
|
||||
.circle_domain()
|
||||
.half_coset,
|
||||
);
|
||||
|
||||
for (log_size, values) in self.sub_accumulations.into_iter().enumerate().skip(1) {
|
||||
let Some(mut values) = values else {
|
||||
continue;
|
||||
};
|
||||
if let Some(prev_poly) = cur_poly {
|
||||
let eval = SecureColumnByCoords {
|
||||
columns: prev_poly.0.map(|c| {
|
||||
c.evaluate_with_twiddles(
|
||||
CanonicCoset::new(log_size as u32).circle_domain(),
|
||||
&twiddles,
|
||||
)
|
||||
.values
|
||||
}),
|
||||
};
|
||||
B::accumulate(&mut values, &eval);
|
||||
}
|
||||
cur_poly = Some(SecureCirclePoly(values.columns.map(|c| {
|
||||
CircleEvaluation::<B, BaseField, BitReversedOrder>::new(
|
||||
CanonicCoset::new(log_size as u32).circle_domain(),
|
||||
c,
|
||||
)
|
||||
.interpolate_with_twiddles(&twiddles)
|
||||
})));
|
||||
}
|
||||
cur_poly.unwrap_or_else(|| {
|
||||
SecureCirclePoly(std::array::from_fn(|_| {
|
||||
CirclePoly::new(Col::<B, BaseField>::zeros(1 << log_size))
|
||||
}))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A domain accumulator for polynomials of a single size.
|
||||
pub struct ColumnAccumulator<'a, B: Backend> {
|
||||
pub random_coeff_powers: Vec<SecureField>,
|
||||
pub col: &'a mut SecureColumnByCoords<B>,
|
||||
}
|
||||
impl<'a> ColumnAccumulator<'a, CpuBackend> {
|
||||
pub fn accumulate(&mut self, index: usize, evaluation: SecureField) {
|
||||
let val = self.col.at(index) + evaluation;
|
||||
self.col.set(index, val);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::array;
|
||||
|
||||
use num_traits::Zero;
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use super::*;
|
||||
use crate::core::backend::cpu::CpuCircleEvaluation;
|
||||
use crate::core::circle::CirclePoint;
|
||||
use crate::core::fields::m31::{M31, P};
|
||||
use crate::qm31;
|
||||
|
||||
#[test]
|
||||
fn test_point_evaluation_accumulator() {
|
||||
// Generate a vector of random sizes with a constant seed.
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
const MAX_LOG_SIZE: u32 = 10;
|
||||
const MASK: u32 = P;
|
||||
let log_sizes = (0..100)
|
||||
.map(|_| rng.gen_range(4..MAX_LOG_SIZE))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Generate random evaluations.
|
||||
let evaluations = log_sizes
|
||||
.iter()
|
||||
.map(|_| M31::from_u32_unchecked(rng.gen::<u32>() & MASK))
|
||||
.collect::<Vec<_>>();
|
||||
let alpha = qm31!(2, 3, 4, 5);
|
||||
|
||||
// Use accumulator.
|
||||
let mut accumulator = PointEvaluationAccumulator::new(alpha);
|
||||
for (_, evaluation) in log_sizes.iter().zip(evaluations.iter()) {
|
||||
accumulator.accumulate((*evaluation).into());
|
||||
}
|
||||
let accumulator_res = accumulator.finalize();
|
||||
|
||||
// Use direct computation.
|
||||
let mut res = SecureField::default();
|
||||
for evaluation in evaluations.iter() {
|
||||
res = res * alpha + *evaluation;
|
||||
}
|
||||
|
||||
assert_eq!(accumulator_res, res);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_evaluation_accumulator() {
|
||||
// Generate a vector of random sizes with a constant seed.
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
const LOG_SIZE_MIN: u32 = 4;
|
||||
const LOG_SIZE_BOUND: u32 = 10;
|
||||
const MASK: u32 = P;
|
||||
let mut log_sizes = (0..100)
|
||||
.map(|_| rng.gen_range(LOG_SIZE_MIN..LOG_SIZE_BOUND))
|
||||
.collect::<Vec<_>>();
|
||||
log_sizes.sort();
|
||||
|
||||
// Generate random evaluations.
|
||||
let evaluations = log_sizes
|
||||
.iter()
|
||||
.map(|log_size| {
|
||||
(0..(1 << *log_size))
|
||||
.map(|_| M31::from_u32_unchecked(rng.gen::<u32>() & MASK))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let alpha = qm31!(2, 3, 4, 5);
|
||||
|
||||
// Use accumulator.
|
||||
let mut accumulator = DomainEvaluationAccumulator::<CpuBackend>::new(
|
||||
alpha,
|
||||
LOG_SIZE_BOUND,
|
||||
evaluations.len(),
|
||||
);
|
||||
let n_cols_per_size: [(u32, usize); (LOG_SIZE_BOUND - LOG_SIZE_MIN) as usize] =
|
||||
array::from_fn(|i| {
|
||||
let current_log_size = LOG_SIZE_MIN + i as u32;
|
||||
let n_cols = log_sizes
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|&log_size| log_size == current_log_size)
|
||||
.count();
|
||||
(current_log_size, n_cols)
|
||||
});
|
||||
let mut cols = accumulator.columns(n_cols_per_size);
|
||||
let mut eval_chunk_offset = 0;
|
||||
for (log_size, n_cols) in n_cols_per_size.iter() {
|
||||
for index in 0..(1 << log_size) {
|
||||
let mut val = SecureField::zero();
|
||||
for (eval_index, (col_log_size, evaluation)) in
|
||||
log_sizes.iter().zip(evaluations.iter()).enumerate()
|
||||
{
|
||||
if *log_size != *col_log_size {
|
||||
continue;
|
||||
}
|
||||
|
||||
// The random coefficient powers chunk is in regular order.
|
||||
let random_coeff_chunk =
|
||||
&cols[(log_size - LOG_SIZE_MIN) as usize].random_coeff_powers;
|
||||
val += random_coeff_chunk
|
||||
[random_coeff_chunk.len() - 1 - (eval_index - eval_chunk_offset)]
|
||||
* evaluation[index];
|
||||
}
|
||||
cols[(log_size - LOG_SIZE_MIN) as usize].accumulate(index, val);
|
||||
}
|
||||
eval_chunk_offset += n_cols;
|
||||
}
|
||||
let accumulator_poly = accumulator.finalize();
|
||||
|
||||
// Pick an arbitrary sample point.
|
||||
let point = CirclePoint::<SecureField>::get_point(98989892);
|
||||
let accumulator_res = accumulator_poly.eval_at_point(point);
|
||||
|
||||
// Use direct computation.
|
||||
let mut res = SecureField::default();
|
||||
for (log_size, values) in log_sizes.into_iter().zip(evaluations) {
|
||||
res = res * alpha
|
||||
+ CpuCircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), values)
|
||||
.interpolate()
|
||||
.eval_at_point(point);
|
||||
}
|
||||
|
||||
assert_eq!(accumulator_res, res);
|
||||
}
|
||||
}
|
||||
80
Stwo_wrapper/crates/prover/src/core/air/components.rs
Normal file
@ -0,0 +1,80 @@
|
||||
use itertools::Itertools;
|
||||
|
||||
use super::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
|
||||
use super::{Component, ComponentProver, Trace};
|
||||
use crate::core::backend::Backend;
|
||||
use crate::core::circle::CirclePoint;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::pcs::TreeVec;
|
||||
use crate::core::poly::circle::SecureCirclePoly;
|
||||
use crate::core::ColumnVec;
|
||||
|
||||
pub struct Components<'a>(pub Vec<&'a dyn Component>);
|
||||
|
||||
impl<'a> Components<'a> {
|
||||
pub fn composition_log_degree_bound(&self) -> u32 {
|
||||
self.0
|
||||
.iter()
|
||||
.map(|component| component.max_constraint_log_degree_bound())
|
||||
.max()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn mask_points(
|
||||
&self,
|
||||
point: CirclePoint<SecureField>,
|
||||
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
|
||||
TreeVec::concat_cols(self.0.iter().map(|component| component.mask_points(point)))
|
||||
}
|
||||
|
||||
pub fn eval_composition_polynomial_at_point(
|
||||
&self,
|
||||
point: CirclePoint<SecureField>,
|
||||
mask_values: &TreeVec<Vec<Vec<SecureField>>>,
|
||||
random_coeff: SecureField,
|
||||
) -> SecureField {
|
||||
// Accumulator for the random linear combination over powers of `random_coeff`.
|
||||
let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff);
|
||||
|
||||
for component in &self.0 {
|
||||
component.evaluate_constraint_quotients_at_point(
|
||||
point,
|
||||
mask_values,
|
||||
&mut evaluation_accumulator,
|
||||
)
|
||||
}
|
||||
evaluation_accumulator.finalize()
|
||||
}
|
||||
|
||||
pub fn column_log_sizes(&self) -> TreeVec<ColumnVec<u32>> {
|
||||
TreeVec::concat_cols(
|
||||
self.0
|
||||
.iter()
|
||||
.map(|component| component.trace_log_degree_bounds()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ComponentProvers<'a, B: Backend>(pub Vec<&'a dyn ComponentProver<B>>);
|
||||
|
||||
impl<'a, B: Backend> ComponentProvers<'a, B> {
|
||||
pub fn components(&self) -> Components<'_> {
|
||||
Components(self.0.iter().map(|c| *c as &dyn Component).collect_vec())
|
||||
}
|
||||
pub fn compute_composition_polynomial(
|
||||
&self,
|
||||
random_coeff: SecureField,
|
||||
trace: &Trace<'_, B>,
|
||||
) -> SecureCirclePoly<B> {
|
||||
let total_constraints: usize = self.0.iter().map(|c| c.n_constraints()).sum();
|
||||
let mut accumulator = DomainEvaluationAccumulator::new(
|
||||
random_coeff,
|
||||
self.components().composition_log_degree_bound(),
|
||||
total_constraints,
|
||||
);
|
||||
for component in &self.0 {
|
||||
component.evaluate_constraint_quotients_on_domain(trace, &mut accumulator)
|
||||
}
|
||||
accumulator.finalize()
|
||||
}
|
||||
}
|
||||
91
Stwo_wrapper/crates/prover/src/core/air/mask.rs
Normal file
@ -0,0 +1,91 @@
|
||||
use std::collections::HashSet;
|
||||
use std::vec;
|
||||
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::core::circle::CirclePoint;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::poly::circle::CanonicCoset;
|
||||
use crate::core::ColumnVec;
|
||||
|
||||
/// Mask holds a vector with an entry for each column.
|
||||
/// Each entry holds a list of mask items, which are the offsets of the mask at that column.
|
||||
type Mask = ColumnVec<Vec<usize>>;
|
||||
|
||||
/// Returns the same point for each mask item.
|
||||
/// Should be used where all the mask items have no shift from the constraint point.
|
||||
pub fn fixed_mask_points(
|
||||
mask: &Mask,
|
||||
point: CirclePoint<SecureField>,
|
||||
) -> ColumnVec<Vec<CirclePoint<SecureField>>> {
|
||||
assert_eq!(
|
||||
mask.iter()
|
||||
.flat_map(|mask_entry| mask_entry.iter().collect::<HashSet<_>>())
|
||||
.collect::<HashSet<&usize>>()
|
||||
.into_iter()
|
||||
.collect_vec(),
|
||||
vec![&0]
|
||||
);
|
||||
mask.iter()
|
||||
.map(|mask_entry| mask_entry.iter().map(|_| point).collect())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// For each mask item, returns the constraint point shifted by the corresponding point of the column's trace domain.
|
||||
/// Should be used where the mask items are shifted from the constraint point.
|
||||
pub fn shifted_mask_points(
|
||||
mask: &Mask,
|
||||
domains: &[CanonicCoset],
|
||||
point: CirclePoint<SecureField>,
|
||||
) -> ColumnVec<Vec<CirclePoint<SecureField>>> {
|
||||
mask.iter()
|
||||
.zip(domains.iter())
|
||||
.map(|(mask_entry, domain)| {
|
||||
mask_entry
|
||||
.iter()
|
||||
.map(|mask_item| point + domain.at(*mask_item).into_ef())
|
||||
.collect()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::core::air::mask::{fixed_mask_points, shifted_mask_points};
|
||||
use crate::core::circle::CirclePoint;
|
||||
use crate::core::poly::circle::CanonicCoset;
|
||||
|
||||
#[test]
|
||||
fn test_mask_fixed_points() {
|
||||
let mask = vec![vec![0], vec![0]];
|
||||
let constraint_point = CirclePoint::get_point(1234);
|
||||
|
||||
let points = fixed_mask_points(&mask, constraint_point);
|
||||
|
||||
assert_eq!(points.len(), 2);
|
||||
assert_eq!(points[0].len(), 1);
|
||||
assert_eq!(points[1].len(), 1);
|
||||
assert_eq!(points[0][0], constraint_point);
|
||||
assert_eq!(points[1][0], constraint_point);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mask_shifted_points() {
|
||||
let mask = vec![vec![0, 1], vec![0, 1, 2]];
|
||||
let constraint_point = CirclePoint::get_point(1234);
|
||||
let domains = (0..mask.len() as u32)
|
||||
.map(|i| CanonicCoset::new(7 + i))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let points = shifted_mask_points(&mask, &domains, constraint_point);
|
||||
|
||||
assert_eq!(points.len(), 2);
|
||||
assert_eq!(points[0].len(), 2);
|
||||
assert_eq!(points[1].len(), 3);
|
||||
assert_eq!(points[0][0], constraint_point + domains[0].at(0).into_ef());
|
||||
assert_eq!(points[0][1], constraint_point + domains[0].at(1).into_ef());
|
||||
assert_eq!(points[1][0], constraint_point + domains[1].at(0).into_ef());
|
||||
assert_eq!(points[1][1], constraint_point + domains[1].at(1).into_ef());
|
||||
assert_eq!(points[1][2], constraint_point + domains[1].at(2).into_ef());
|
||||
}
|
||||
}
|
||||
76
Stwo_wrapper/crates/prover/src/core/air/mod.rs
Normal file
@ -0,0 +1,76 @@
|
||||
pub use components::{ComponentProvers, Components};
|
||||
|
||||
use self::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
|
||||
use super::backend::Backend;
|
||||
use super::circle::CirclePoint;
|
||||
use super::fields::m31::BaseField;
|
||||
use super::fields::qm31::SecureField;
|
||||
use super::pcs::TreeVec;
|
||||
use super::poly::circle::{CircleEvaluation, CirclePoly};
|
||||
use super::poly::BitReversedOrder;
|
||||
use super::ColumnVec;
|
||||
|
||||
pub mod accumulation;
|
||||
mod components;
|
||||
pub mod mask;
|
||||
|
||||
/// Arithmetic Intermediate Representation (AIR).
|
||||
/// An Air instance is assumed to already contain all the information needed to
|
||||
/// evaluate the constraints.
|
||||
/// For instance, all interaction elements are assumed to be present in it.
|
||||
/// Therefore, an AIR is generated only after the initial trace commitment phase.
|
||||
// TODO(spapini): consider renaming this struct.
|
||||
pub trait Air {
|
||||
fn components(&self) -> Vec<&dyn Component>;
|
||||
}
|
||||
|
||||
pub trait AirProver<B: Backend>: Air {
|
||||
fn component_provers(&self) -> Vec<&dyn ComponentProver<B>>;
|
||||
}
|
||||
|
||||
/// A component is a set of trace columns of various sizes along with a set of
|
||||
/// constraints on them.
|
||||
pub trait Component {
|
||||
fn n_constraints(&self) -> usize;
|
||||
|
||||
fn max_constraint_log_degree_bound(&self) -> u32;
|
||||
|
||||
/// Returns the degree bounds of each trace column. The returned TreeVec should be of size
|
||||
/// `n_interaction_phases`.
|
||||
fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>>;
|
||||
|
||||
/// Returns the mask points for each trace column. The returned TreeVec should be of size
|
||||
/// `n_interaction_phases`.
|
||||
fn mask_points(
|
||||
&self,
|
||||
point: CirclePoint<SecureField>,
|
||||
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>>;
|
||||
|
||||
/// Evaluates the constraint quotients combination of the component at a point.
|
||||
fn evaluate_constraint_quotients_at_point(
|
||||
&self,
|
||||
point: CirclePoint<SecureField>,
|
||||
mask: &TreeVec<ColumnVec<Vec<SecureField>>>,
|
||||
evaluation_accumulator: &mut PointEvaluationAccumulator,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait ComponentProver<B: Backend>: Component {
|
||||
/// Evaluates the constraint quotients of the component on the evaluation domain.
|
||||
/// Accumulates quotients in `evaluation_accumulator`.
|
||||
fn evaluate_constraint_quotients_on_domain(
|
||||
&self,
|
||||
trace: &Trace<'_, B>,
|
||||
evaluation_accumulator: &mut DomainEvaluationAccumulator<B>,
|
||||
);
|
||||
}
|
||||
|
||||
/// The set of polynomials that make up the trace.
|
||||
///
|
||||
/// Each polynomial is stored in both coefficient and evaluation form (for efficiency).
|
||||
pub struct Trace<'a, B: Backend> {
|
||||
/// Polynomials for each column.
|
||||
pub polys: TreeVec<ColumnVec<&'a CirclePoly<B>>>,
|
||||
/// Evaluations for each column (evaluated on their commitment domains).
|
||||
pub evals: TreeVec<ColumnVec<&'a CircleEvaluation<B, BaseField, BitReversedOrder>>>,
|
||||
}
|
||||
@ -0,0 +1,12 @@
use super::CpuBackend;
use crate::core::air::accumulation::AccumulationOps;
use crate::core::fields::secure_column::SecureColumnByCoords;

impl AccumulationOps for CpuBackend {
    fn accumulate(column: &mut SecureColumnByCoords<Self>, other: &SecureColumnByCoords<Self>) {
        for i in 0..column.len() {
            let res_coeff = column.at(i) + other.at(i);
            column.set(i, res_coeff);
        }
    }
}
24
Stwo_wrapper/crates/prover/src/core/backend/cpu/blake2s.rs
Normal file
@ -0,0 +1,24 @@
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::core::backend::CpuBackend;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::vcs::blake2_hash::Blake2sHash;
|
||||
use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
|
||||
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};
|
||||
|
||||
impl MerkleOps<Blake2sMerkleHasher> for CpuBackend {
|
||||
fn commit_on_layer(
|
||||
log_size: u32,
|
||||
prev_layer: Option<&Vec<Blake2sHash>>,
|
||||
columns: &[&Vec<BaseField>],
|
||||
) -> Vec<Blake2sHash> {
|
||||
(0..(1 << log_size))
|
||||
.map(|i| {
|
||||
Blake2sMerkleHasher::hash_node(
|
||||
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
|
||||
&columns.iter().map(|column| column[i]).collect_vec(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
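A usage sketch, not from the source, with hypothetical column values: commit on the leaf layer of a single four-element column, then on the layer above it.

let column = vec![
    BaseField::from(1),
    BaseField::from(2),
    BaseField::from(3),
    BaseField::from(4),
];
// Leaf layer: 4 hashes, one per index, no previous layer.
let leaves = <CpuBackend as MerkleOps<Blake2sMerkleHasher>>::commit_on_layer(2, None, &[&column]);
// Next layer up: 2 hashes, each over a pair of children, with no columns injected.
let inner = <CpuBackend as MerkleOps<Blake2sMerkleHasher>>::commit_on_layer(1, Some(&leaves), &[]);
assert_eq!(leaves.len(), 4);
assert_eq!(inner.len(), 2);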
376
Stwo_wrapper/crates/prover/src/core/backend/cpu/circle.rs
Normal file
@ -0,0 +1,376 @@
|
||||
use num_traits::Zero;
|
||||
|
||||
use super::CpuBackend;
|
||||
use crate::core::backend::{Col, ColumnOps};
|
||||
use crate::core::circle::{CirclePoint, Coset};
|
||||
use crate::core::fft::{butterfly, ibutterfly};
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::{ExtensionOf, FieldExpOps};
|
||||
use crate::core::poly::circle::{
|
||||
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
|
||||
};
|
||||
use crate::core::poly::twiddles::TwiddleTree;
|
||||
use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold};
|
||||
use crate::core::poly::BitReversedOrder;
|
||||
use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order};
|
||||
|
||||
impl PolyOps for CpuBackend {
|
||||
type Twiddles = Vec<BaseField>;
|
||||
|
||||
fn new_canonical_ordered(
|
||||
coset: CanonicCoset,
|
||||
values: Col<Self, BaseField>,
|
||||
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
|
||||
let domain = coset.circle_domain();
|
||||
assert_eq!(values.len(), domain.size());
|
||||
let mut new_values = coset_order_to_circle_domain_order(&values);
|
||||
CpuBackend::bit_reverse_column(&mut new_values);
|
||||
CircleEvaluation::new(domain, new_values)
|
||||
}
|
||||
|
||||
fn interpolate(
|
||||
eval: CircleEvaluation<Self, BaseField, BitReversedOrder>,
|
||||
twiddles: &TwiddleTree<Self>,
|
||||
) -> CirclePoly<Self> {
|
||||
assert!(eval.domain.half_coset.is_doubling_of(twiddles.root_coset));
|
||||
|
||||
let mut values = eval.values;
|
||||
|
||||
if eval.domain.log_size() == 1 {
|
||||
let y = eval.domain.half_coset.initial.y;
|
||||
let n = BaseField::from(2);
|
||||
let yn_inv = (y * n).inverse();
|
||||
let y_inv = yn_inv * n;
|
||||
let n_inv = yn_inv * y;
|
||||
let (mut v0, mut v1) = (values[0], values[1]);
|
||||
ibutterfly(&mut v0, &mut v1, y_inv);
|
||||
return CirclePoly::new(vec![v0 * n_inv, v1 * n_inv]);
|
||||
}
|
||||
|
||||
if eval.domain.log_size() == 2 {
|
||||
let CirclePoint { x, y } = eval.domain.half_coset.initial;
|
||||
let n = BaseField::from(4);
|
||||
let xyn_inv = (x * y * n).inverse();
|
||||
let x_inv = xyn_inv * y * n;
|
||||
let y_inv = xyn_inv * x * n;
|
||||
let n_inv = xyn_inv * x * y;
|
||||
let (mut v0, mut v1, mut v2, mut v3) = (values[0], values[1], values[2], values[3]);
|
||||
ibutterfly(&mut v0, &mut v1, y_inv);
|
||||
ibutterfly(&mut v2, &mut v3, -y_inv);
|
||||
ibutterfly(&mut v0, &mut v2, x_inv);
|
||||
ibutterfly(&mut v1, &mut v3, x_inv);
|
||||
return CirclePoly::new(vec![v0 * n_inv, v1 * n_inv, v2 * n_inv, v3 * n_inv]);
|
||||
}
|
||||
|
||||
let line_twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles);
|
||||
let circle_twiddles = circle_twiddles_from_line_twiddles(line_twiddles[0]);
|
||||
|
||||
for (h, t) in circle_twiddles.enumerate() {
|
||||
fft_layer_loop(&mut values, 0, h, t, ibutterfly);
|
||||
}
|
||||
for (layer, layer_twiddles) in line_twiddles.into_iter().enumerate() {
|
||||
for (h, &t) in layer_twiddles.iter().enumerate() {
|
||||
fft_layer_loop(&mut values, layer + 1, h, t, ibutterfly);
|
||||
}
|
||||
}
|
||||
|
||||
// Divide all values by 2^log_size.
|
||||
let inv = BaseField::from_u32_unchecked(eval.domain.size() as u32).inverse();
|
||||
for val in &mut values {
|
||||
*val *= inv;
|
||||
}
|
||||
|
||||
CirclePoly::new(values)
|
||||
}
|
||||
|
||||
fn eval_at_point(poly: &CirclePoly<Self>, point: CirclePoint<SecureField>) -> SecureField {
|
||||
if poly.log_size() == 0 {
|
||||
return poly.coeffs[0].into();
|
||||
}
|
||||
|
||||
let mut mappings = vec![point.y];
|
||||
let mut x = point.x;
|
||||
for _ in 1..poly.log_size() {
|
||||
mappings.push(x);
|
||||
x = CirclePoint::double_x(x);
|
||||
}
|
||||
mappings.reverse();
|
||||
|
||||
fold(&poly.coeffs, &mappings)
|
||||
}
|
||||
|
||||
fn extend(poly: &CirclePoly<Self>, log_size: u32) -> CirclePoly<Self> {
|
||||
assert!(log_size >= poly.log_size());
|
||||
let mut coeffs = Vec::with_capacity(1 << log_size);
|
||||
coeffs.extend_from_slice(&poly.coeffs);
|
||||
coeffs.resize(1 << log_size, BaseField::zero());
|
||||
CirclePoly::new(coeffs)
|
||||
}
|
||||
|
||||
fn evaluate(
|
||||
poly: &CirclePoly<Self>,
|
||||
domain: CircleDomain,
|
||||
twiddles: &TwiddleTree<Self>,
|
||||
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
|
||||
assert!(domain.half_coset.is_doubling_of(twiddles.root_coset));
|
||||
|
||||
let mut values = poly.extend(domain.log_size()).coeffs;
|
||||
|
||||
if domain.log_size() == 1 {
|
||||
let (mut v0, mut v1) = (values[0], values[1]);
|
||||
butterfly(&mut v0, &mut v1, domain.half_coset.initial.y);
|
||||
return CircleEvaluation::new(domain, vec![v0, v1]);
|
||||
}
|
||||
|
||||
if domain.log_size() == 2 {
|
||||
let (mut v0, mut v1, mut v2, mut v3) = (values[0], values[1], values[2], values[3]);
|
||||
let CirclePoint { x, y } = domain.half_coset.initial;
|
||||
butterfly(&mut v0, &mut v2, x);
|
||||
butterfly(&mut v1, &mut v3, x);
|
||||
butterfly(&mut v0, &mut v1, y);
|
||||
butterfly(&mut v2, &mut v3, -y);
|
||||
return CircleEvaluation::new(domain, vec![v0, v1, v2, v3]);
|
||||
}
|
||||
|
||||
let line_twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles);
|
||||
let circle_twiddles = circle_twiddles_from_line_twiddles(line_twiddles[0]);
|
||||
|
||||
for (layer, layer_twiddles) in line_twiddles.iter().enumerate().rev() {
|
||||
for (h, &t) in layer_twiddles.iter().enumerate() {
|
||||
fft_layer_loop(&mut values, layer + 1, h, t, butterfly);
|
||||
}
|
||||
}
|
||||
for (h, t) in circle_twiddles.enumerate() {
|
||||
fft_layer_loop(&mut values, 0, h, t, butterfly);
|
||||
}
|
||||
|
||||
CircleEvaluation::new(domain, values)
|
||||
}
|
||||
|
||||
fn precompute_twiddles(mut coset: Coset) -> TwiddleTree<Self> {
|
||||
const CHUNK_LOG_SIZE: usize = 12;
|
||||
const CHUNK_SIZE: usize = 1 << CHUNK_LOG_SIZE;
|
||||
|
||||
let root_coset = coset;
|
||||
let mut twiddles = Vec::with_capacity(coset.size());
|
||||
for _ in 0..coset.log_size() {
|
||||
let i0 = twiddles.len();
|
||||
twiddles.extend(
|
||||
coset
|
||||
.iter()
|
||||
.take(coset.size() / 2)
|
||||
.map(|p| p.x)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
bit_reverse(&mut twiddles[i0..]);
|
||||
coset = coset.double();
|
||||
}
|
||||
twiddles.push(1.into());
|
||||
|
||||
// Inverse twiddles.
|
||||
// Fallback to the non-chunked version if the domain is not big enough.
|
||||
if CHUNK_SIZE > coset.size() {
|
||||
let itwiddles = twiddles.iter().map(|&t| t.inverse()).collect();
|
||||
return TwiddleTree {
|
||||
root_coset,
|
||||
twiddles,
|
||||
itwiddles,
|
||||
};
|
||||
}
|
||||
|
||||
let mut itwiddles = vec![BaseField::zero(); twiddles.len()];
|
||||
twiddles
|
||||
.array_chunks::<CHUNK_SIZE>()
|
||||
.zip(itwiddles.array_chunks_mut::<CHUNK_SIZE>())
|
||||
.for_each(|(src, dst)| {
|
||||
BaseField::batch_inverse(src, dst);
|
||||
});
|
||||
|
||||
TwiddleTree {
|
||||
root_coset,
|
||||
twiddles,
|
||||
itwiddles,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fft_layer_loop(
|
||||
values: &mut [BaseField],
|
||||
i: usize,
|
||||
h: usize,
|
||||
t: BaseField,
|
||||
butterfly_fn: impl Fn(&mut BaseField, &mut BaseField, BaseField),
|
||||
) {
|
||||
for l in 0..(1 << i) {
|
||||
let idx0 = (h << (i + 1)) + l;
|
||||
let idx1 = idx0 + (1 << i);
|
||||
let (mut val0, mut val1) = (values[idx0], values[idx1]);
|
||||
butterfly_fn(&mut val0, &mut val1, t);
|
||||
(values[idx0], values[idx1]) = (val0, val1);
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the circle twiddles layer (layer 0) from the first line twiddles layer (layer 1).
|
||||
///
|
||||
/// Only works for line twiddles generated from a domain with size `>4`.
|
||||
fn circle_twiddles_from_line_twiddles(
|
||||
first_line_twiddles: &[BaseField],
|
||||
) -> impl Iterator<Item = BaseField> + '_ {
|
||||
// The twiddles for layer 0 can be computed from the twiddles for layer 1.
|
||||
// Since the twiddles are bit reversed, we consider the circle domain in bit reversed order.
|
||||
// Each consecutive 4 points in the bit reversed order of a coset form a circle coset of size 4.
|
||||
// A circle coset of size 4 in bit reversed order looks like this:
|
||||
// [(x, y), (-x, -y), (y, -x), (-y, x)]
|
||||
// Note: This relation is derived from the fact that `M31_CIRCLE_GEN`.repeated_double(ORDER / 4)
|
||||
// == (-1,0), and not (0,1). (0,1) would yield another relation.
|
||||
// The twiddles for layer 0 are the y coordinates:
|
||||
// [y, -y, -x, x]
|
||||
// The twiddles for layer 1 in bit reversed order are the x coordinates of the even indices
|
||||
// points:
|
||||
// [x, y]
|
||||
// Works also for inverse of the twiddles.
|
||||
first_line_twiddles
|
||||
.iter()
|
||||
.array_chunks()
|
||||
.flat_map(|[&x, &y]| [y, -y, -x, x])
|
||||
}
|
||||
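An illustration, not from the source, using arbitrary values: a single line-twiddle pair `[x, y]` expands to the circle twiddles `[y, -y, -x, x]`.

let first_line_layer = [BaseField::from(3), BaseField::from(5)]; // [x, y]
let circle_layer: Vec<BaseField> =
    circle_twiddles_from_line_twiddles(&first_line_layer).collect();
assert_eq!(
    circle_layer,
    vec![
        BaseField::from(5),
        -BaseField::from(5),
        -BaseField::from(3),
        BaseField::from(3),
    ]
);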
|
||||
impl<F: ExtensionOf<BaseField>, EvalOrder> IntoIterator
|
||||
for CircleEvaluation<CpuBackend, F, EvalOrder>
|
||||
{
|
||||
type Item = F;
|
||||
type IntoIter = std::vec::IntoIter<F>;
|
||||
|
||||
/// Creates a consuming iterator over the evaluations.
|
||||
///
|
||||
/// Evaluations are returned in the same order as elements of the domain.
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.values.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::iter::zip;
|
||||
|
||||
use num_traits::One;
|
||||
|
||||
use crate::core::backend::cpu::CpuCirclePoly;
|
||||
use crate::core::circle::CirclePoint;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::poly::circle::CanonicCoset;
|
||||
|
||||
#[test]
|
||||
fn test_eval_at_point_with_4_coeffs() {
|
||||
// Represents the polynomial `1 + 2y + 3x + 4xy`.
|
||||
// Note coefficients are passed in bit reversed order.
|
||||
let poly = CpuCirclePoly::new([1, 3, 2, 4].map(BaseField::from).to_vec());
|
||||
let x = BaseField::from(5).into();
|
||||
let y = BaseField::from(8).into();
|
||||
|
||||
let eval = poly.eval_at_point(CirclePoint { x, y });
|
||||
|
||||
assert_eq!(
|
||||
eval,
|
||||
poly.coeffs[0] + poly.coeffs[1] * y + poly.coeffs[2] * x + poly.coeffs[3] * x * y
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_eval_at_point_with_2_coeffs() {
|
||||
// Represents the polynomial `1 + 2y`.
|
||||
let poly = CpuCirclePoly::new(vec![BaseField::from(1), BaseField::from(2)]);
|
||||
let x = BaseField::from(5).into();
|
||||
let y = BaseField::from(8).into();
|
||||
|
||||
let eval = poly.eval_at_point(CirclePoint { x, y });
|
||||
|
||||
assert_eq!(eval, poly.coeffs[0] + poly.coeffs[1] * y);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_eval_at_point_with_1_coeff() {
|
||||
// Represents the polynomial `1`.
|
||||
let poly = CpuCirclePoly::new(vec![BaseField::one()]);
|
||||
let x = BaseField::from(5).into();
|
||||
let y = BaseField::from(8).into();
|
||||
|
||||
let eval = poly.eval_at_point(CirclePoint { x, y });
|
||||
|
||||
assert_eq!(eval, SecureField::one());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_evaluate_2_coeffs() {
|
||||
let domain = CanonicCoset::new(1).circle_domain();
|
||||
let poly = CpuCirclePoly::new((1..=2).map(BaseField::from).collect());
|
||||
|
||||
let evaluation = poly.clone().evaluate(domain).bit_reverse();
|
||||
|
||||
for (i, (p, eval)) in zip(domain, evaluation).enumerate() {
|
||||
let eval: SecureField = eval.into();
|
||||
assert_eq!(eval, poly.eval_at_point(p.into_ef()), "mismatch at i={i}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_evaluate_4_coeffs() {
|
||||
let domain = CanonicCoset::new(2).circle_domain();
|
||||
let poly = CpuCirclePoly::new((1..=4).map(BaseField::from).collect());
|
||||
|
||||
let evaluation = poly.clone().evaluate(domain).bit_reverse();
|
||||
|
||||
for (i, (x, eval)) in zip(domain, evaluation).enumerate() {
|
||||
let eval: SecureField = eval.into();
|
||||
assert_eq!(eval, poly.eval_at_point(x.into_ef()), "mismatch at i={i}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_evaluate_8_coeffs() {
|
||||
let domain = CanonicCoset::new(3).circle_domain();
|
||||
let poly = CpuCirclePoly::new((1..=8).map(BaseField::from).collect());
|
||||
|
||||
let evaluation = poly.clone().evaluate(domain).bit_reverse();
|
||||
|
||||
for (i, (x, eval)) in zip(domain, evaluation).enumerate() {
|
||||
let eval: SecureField = eval.into();
|
||||
assert_eq!(eval, poly.eval_at_point(x.into_ef()), "mismatch at i={i}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_interpolate_2_evals() {
|
||||
let poly = CpuCirclePoly::new(vec![BaseField::one(), BaseField::from(2)]);
|
||||
let domain = CanonicCoset::new(1).circle_domain();
|
||||
let evals = poly.clone().evaluate(domain);
|
||||
|
||||
let interpolated_poly = evals.interpolate();
|
||||
|
||||
assert_eq!(interpolated_poly.coeffs, poly.coeffs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_interpolate_4_evals() {
|
||||
let poly = CpuCirclePoly::new((1..=4).map(BaseField::from).collect());
|
||||
let domain = CanonicCoset::new(2).circle_domain();
|
||||
let evals = poly.clone().evaluate(domain);
|
||||
|
||||
let interpolated_poly = evals.interpolate();
|
||||
|
||||
assert_eq!(interpolated_poly.coeffs, poly.coeffs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_interpolate_8_evals() {
|
||||
let poly = CpuCirclePoly::new((1..=8).map(BaseField::from).collect());
|
||||
let domain = CanonicCoset::new(3).circle_domain();
|
||||
let evals = poly.clone().evaluate(domain);
|
||||
|
||||
let interpolated_poly = evals.interpolate();
|
||||
|
||||
assert_eq!(interpolated_poly.coeffs, poly.coeffs);
|
||||
}
|
||||
}
|
||||
144
Stwo_wrapper/crates/prover/src/core/backend/cpu/fri.rs
Normal file
@ -0,0 +1,144 @@
|
||||
use super::CpuBackend;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::secure_column::SecureColumnByCoords;
|
||||
use crate::core::fri::{fold_circle_into_line, fold_line, FriOps};
|
||||
use crate::core::poly::circle::SecureEvaluation;
|
||||
use crate::core::poly::line::LineEvaluation;
|
||||
use crate::core::poly::twiddles::TwiddleTree;
|
||||
use crate::core::poly::BitReversedOrder;
|
||||
|
||||
// TODO(spapini): Optimize these functions as well.
|
||||
impl FriOps for CpuBackend {
|
||||
fn fold_line(
|
||||
eval: &LineEvaluation<Self>,
|
||||
alpha: SecureField,
|
||||
_twiddles: &TwiddleTree<Self>,
|
||||
) -> LineEvaluation<Self> {
|
||||
fold_line(eval, alpha)
|
||||
}
|
||||
fn fold_circle_into_line(
|
||||
dst: &mut LineEvaluation<Self>,
|
||||
src: &SecureEvaluation<Self, BitReversedOrder>,
|
||||
alpha: SecureField,
|
||||
_twiddles: &TwiddleTree<Self>,
|
||||
) {
|
||||
fold_circle_into_line(dst, src, alpha)
|
||||
}
|
||||
|
||||
fn decompose(
|
||||
eval: &SecureEvaluation<Self, BitReversedOrder>,
|
||||
) -> (SecureEvaluation<Self, BitReversedOrder>, SecureField) {
|
||||
let lambda = Self::decomposition_coefficient(eval);
|
||||
let mut g_values = unsafe { SecureColumnByCoords::<Self>::uninitialized(eval.len()) };
|
||||
|
||||
let domain_size = eval.len();
|
||||
let half_domain_size = domain_size / 2;
|
||||
|
||||
for i in 0..half_domain_size {
|
||||
let x = eval.values.at(i);
|
||||
let val = x - lambda;
|
||||
g_values.set(i, val);
|
||||
}
|
||||
for i in half_domain_size..domain_size {
|
||||
let x = eval.values.at(i);
|
||||
let val = x + lambda;
|
||||
g_values.set(i, val);
|
||||
}
|
||||
|
||||
let g = SecureEvaluation::new(eval.domain, g_values);
|
||||
(g, lambda)
|
||||
}
|
||||
}
|
||||
|
||||
impl CpuBackend {
|
||||
/// Used to decompose a general polynomial into a polynomial inside the FFT space and
/// the remainder terms.
/// A coset-diff on a [`CirclePoly`] that is in the FFT space will return zero.
///
/// Let N be the domain size and let h be a coset of size N/2. Using lemma #7 from the CircleStark
/// paper, <f, V_h> = lambda * <V_h, V_h> = lambda * N, so
/// lambda * N = f(0) * V_h(0) + f(1) * V_h(1) + .. + f(N-1) * V_h(N-1).
/// The vanishing polynomial of a canonic coset sized half the circle
/// domain, evaluated on the circle domain, is [(1, -1, -1, 1)] repeating. This becomes
/// alternating [+-1] in our NaturalOrder, and [(+, +, +, ... , -, -)] in bit-reversed order.
/// Explicitly, lambda * N = sum(+f(0..N/2)) + sum(-f(N/2..)).
///
/// # Warning
/// This function assumes the blowup factor is 2.
|
||||
///
|
||||
/// [`CirclePoly`]: crate::core::poly::circle::CirclePoly
|
||||
fn decomposition_coefficient(eval: &SecureEvaluation<Self, BitReversedOrder>) -> SecureField {
|
||||
let domain_size = 1 << eval.domain.log_size();
|
||||
let half_domain_size = domain_size / 2;
|
||||
|
||||
// eval is in bit-reverse, hence all the positive factors are in the first half, opposite to
|
||||
// the latter.
|
||||
let a_sum = (0..half_domain_size)
|
||||
.map(|i| eval.values.at(i))
|
||||
.sum::<SecureField>();
|
||||
let b_sum = (half_domain_size..domain_size)
|
||||
.map(|i| eval.values.at(i))
|
||||
.sum::<SecureField>();
|
||||
|
||||
// lambda = sum(+-f(p)) / 2N.
|
||||
(a_sum - b_sum) / BaseField::from_u32_unchecked(domain_size as u32)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use num_traits::Zero;
|
||||
|
||||
use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly};
|
||||
use crate::core::backend::CpuBackend;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::secure_column::SecureColumnByCoords;
|
||||
use crate::core::fri::FriOps;
|
||||
use crate::core::poly::circle::{CanonicCoset, SecureEvaluation};
|
||||
use crate::core::poly::BitReversedOrder;
|
||||
use crate::m31;
|
||||
|
||||
#[test]
|
||||
fn decompose_coeff_out_fft_space_test() {
|
||||
for domain_log_size in 5..12 {
|
||||
let domain_log_half_size = domain_log_size - 1;
|
||||
let s = CanonicCoset::new(domain_log_size);
|
||||
let domain = s.circle_domain();
|
||||
|
||||
let mut coeffs = vec![BaseField::zero(); 1 << domain_log_size];
|
||||
|
||||
// Polynomial is out of FFT space.
|
||||
coeffs[1 << domain_log_half_size] = m31!(1);
|
||||
assert!(!CpuCirclePoly::new(coeffs.clone()).is_in_fft_space(domain_log_half_size));
|
||||
|
||||
let poly = CpuCirclePoly::new(coeffs);
|
||||
let values = poly.evaluate(domain);
|
||||
let secure_column = SecureColumnByCoords {
|
||||
columns: [
|
||||
values.values.clone(),
|
||||
values.values.clone(),
|
||||
values.values.clone(),
|
||||
values.values.clone(),
|
||||
],
|
||||
};
|
||||
let secure_eval = SecureEvaluation::<CpuBackend, BitReversedOrder>::new(
|
||||
domain,
|
||||
secure_column.clone(),
|
||||
);
|
||||
|
||||
let (g, lambda) = CpuBackend::decompose(&secure_eval);
|
||||
|
||||
// Sanity check.
|
||||
assert_ne!(lambda, SecureField::zero());
|
||||
|
||||
// Assert the new polynomial is in the FFT space.
|
||||
for i in 0..4 {
|
||||
let basefield_column = g.columns[i].clone();
|
||||
let eval = CpuCircleEvaluation::new(domain, basefield_column);
|
||||
let coeffs = eval.interpolate().coeffs;
|
||||
assert!(CpuCirclePoly::new(coeffs).is_in_fft_space(domain_log_half_size));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
18
Stwo_wrapper/crates/prover/src/core/backend/cpu/grind.rs
Normal file
@ -0,0 +1,18 @@
use super::CpuBackend;
use crate::core::channel::Channel;
use crate::core::proof_of_work::GrindOps;

impl<C: Channel> GrindOps<C> for CpuBackend {
    fn grind(channel: &C, pow_bits: u32) -> u64 {
        // TODO(spapini): This is a naive implementation. Optimize it.
        let mut nonce = 0;
        loop {
            let mut channel = channel.clone();
            channel.mix_nonce(nonce);
            if channel.trailing_zeros() >= pow_bits {
                return nonce;
            }
            nonce += 1;
        }
    }
}
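A usage sketch, not from the source; it assumes the crate's `Blake2sChannel`, which implements `Channel` and `Default`:

use crate::core::channel::Blake2sChannel;

let channel = Blake2sChannel::default();
let nonce = <CpuBackend as GrindOps<Blake2sChannel>>::grind(&channel, 10);
// The returned nonce satisfies the proof-of-work condition.
let mut check = channel.clone();
check.mix_nonce(nonce);
assert!(check.trailing_zeros() >= 10);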
448
Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/gkr.rs
Normal file
@ -0,0 +1,448 @@
|
||||
use std::ops::Index;
|
||||
|
||||
use num_traits::{One, Zero};
|
||||
|
||||
use crate::core::backend::CpuBackend;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::{ExtensionOf, Field};
|
||||
use crate::core::lookups::gkr_prover::{
|
||||
correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer,
|
||||
};
|
||||
use crate::core::lookups::mle::{Mle, MleOps};
|
||||
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
|
||||
use crate::core::lookups::utils::{Fraction, Reciprocal, UnivariatePoly};
|
||||
|
||||
impl GkrOps for CpuBackend {
|
||||
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField> {
|
||||
Mle::new(gen_eq_evals(y, v))
|
||||
}
|
||||
|
||||
fn next_layer(layer: &Layer<Self>) -> Layer<Self> {
|
||||
match layer {
|
||||
Layer::GrandProduct(layer) => next_grand_product_layer(layer),
|
||||
Layer::LogUpGeneric {
|
||||
numerators,
|
||||
denominators,
|
||||
} => next_logup_layer(MleExpr::Mle(numerators), denominators),
|
||||
Layer::LogUpMultiplicities {
|
||||
numerators,
|
||||
denominators,
|
||||
} => next_logup_layer(MleExpr::Mle(numerators), denominators),
|
||||
Layer::LogUpSingles { denominators } => {
|
||||
next_logup_layer(MleExpr::Constant(BaseField::one()), denominators)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn sum_as_poly_in_first_variable(
|
||||
h: &GkrMultivariatePolyOracle<'_, Self>,
|
||||
claim: SecureField,
|
||||
) -> UnivariatePoly<SecureField> {
|
||||
let n_variables = h.n_variables();
|
||||
assert!(!n_variables.is_zero());
|
||||
let n_terms = 1 << (n_variables - 1);
|
||||
let eq_evals = h.eq_evals.as_ref();
|
||||
// Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube.
|
||||
let y = eq_evals.y();
|
||||
let lambda = h.lambda;
|
||||
|
||||
let (mut eval_at_0, mut eval_at_2) = match &h.input_layer {
|
||||
Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_terms),
|
||||
Layer::LogUpGeneric {
|
||||
numerators,
|
||||
denominators,
|
||||
} => eval_logup_sum(eq_evals, numerators, denominators, n_terms, lambda),
|
||||
Layer::LogUpMultiplicities {
|
||||
numerators,
|
||||
denominators,
|
||||
} => eval_logup_sum(eq_evals, numerators, denominators, n_terms, lambda),
|
||||
Layer::LogUpSingles { denominators } => {
|
||||
eval_logup_singles_sum(eq_evals, denominators, n_terms, lambda)
|
||||
}
|
||||
};
|
||||
|
||||
eval_at_0 *= h.eq_fixed_var_correction;
|
||||
eval_at_2 *= h.eq_fixed_var_correction;
|
||||
correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables)
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`.
|
||||
///
|
||||
/// Output of the form: `(eval_at_0, eval_at_2)`.
|
||||
fn eval_grand_product_sum(
|
||||
eq_evals: &EqEvals<CpuBackend>,
|
||||
input_layer: &Mle<CpuBackend, SecureField>,
|
||||
n_terms: usize,
|
||||
) -> (SecureField, SecureField) {
|
||||
let mut eval_at_0 = SecureField::zero();
|
||||
let mut eval_at_2 = SecureField::zero();
|
||||
|
||||
for i in 0..n_terms {
|
||||
// Input polynomial at points `(r, {0, 1, 2}, bits(i), {0, 1})`.
|
||||
let inp_at_r0i0 = input_layer[i * 2];
|
||||
let inp_at_r0i1 = input_layer[i * 2 + 1];
|
||||
let inp_at_r1i0 = input_layer[(n_terms + i) * 2];
|
||||
let inp_at_r1i1 = input_layer[(n_terms + i) * 2 + 1];
|
||||
// Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)`
|
||||
// => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)`
|
||||
// TODO(andrew): Consider evaluation at `1/2` to save an addition operation since
|
||||
// `inp(r, 1/2, x) = 1/2 * (inp(r, 1, x) + inp(r, 0, x))`. `1/2 * ...` can be factored
|
||||
// outside the loop.
|
||||
let inp_at_r2i0 = inp_at_r1i0.double() - inp_at_r0i0;
|
||||
let inp_at_r2i1 = inp_at_r1i1.double() - inp_at_r0i1;
|
||||
|
||||
// Product polynomial `prod(x) = inp(x, 0) * inp(x, 1)` at points `(r, {0, 2}, bits(i))`.
|
||||
let prod_at_r2i = inp_at_r2i0 * inp_at_r2i1;
|
||||
let prod_at_r0i = inp_at_r0i0 * inp_at_r0i1;
|
||||
|
||||
let eq_eval_at_0i = eq_evals[i];
|
||||
eval_at_0 += eq_eval_at_0i * prod_at_r0i;
|
||||
eval_at_2 += eq_eval_at_0i * prod_at_r2i;
|
||||
}
|
||||
|
||||
(eval_at_0, eval_at_2)
|
||||
}
|
||||
|
||||
/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_numer(r, t, x, 0) * inp_denom(r, t, x, 1) +
|
||||
/// inp_numer(r, t, x, 1) * inp_denom(r, t, x, 0) + lambda * inp_denom(r, t, x, 0) * inp_denom(r, t,
|
||||
/// x, 1))` at `t=0` and `t=2`.
|
||||
///
|
||||
/// Output of the form: `(eval_at_0, eval_at_2)`.
|
||||
fn eval_logup_sum<F: Field>(
|
||||
eq_evals: &EqEvals<CpuBackend>,
|
||||
input_numerators: &Mle<CpuBackend, F>,
|
||||
input_denominators: &Mle<CpuBackend, SecureField>,
|
||||
n_terms: usize,
|
||||
lambda: SecureField,
|
||||
) -> (SecureField, SecureField)
|
||||
where
|
||||
SecureField: ExtensionOf<F> + Field,
|
||||
{
|
||||
let mut eval_at_0 = SecureField::zero();
|
||||
let mut eval_at_2 = SecureField::zero();
|
||||
|
||||
for i in 0..n_terms {
|
||||
// Input polynomials at points `(r, {0, 1, 2}, bits(i), {0, 1})`.
|
||||
let inp_numer_at_r0i0 = input_numerators[i * 2];
|
||||
let inp_denom_at_r0i0 = input_denominators[i * 2];
|
||||
let inp_numer_at_r0i1 = input_numerators[i * 2 + 1];
|
||||
let inp_denom_at_r0i1 = input_denominators[i * 2 + 1];
|
||||
let inp_numer_at_r1i0 = input_numerators[(n_terms + i) * 2];
|
||||
let inp_denom_at_r1i0 = input_denominators[(n_terms + i) * 2];
|
||||
let inp_numer_at_r1i1 = input_numerators[(n_terms + i) * 2 + 1];
|
||||
let inp_denom_at_r1i1 = input_denominators[(n_terms + i) * 2 + 1];
|
||||
// Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)`
|
||||
// => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)`
|
||||
let inp_numer_at_r2i0 = inp_numer_at_r1i0.double() - inp_numer_at_r0i0;
|
||||
let inp_denom_at_r2i0 = inp_denom_at_r1i0.double() - inp_denom_at_r0i0;
|
||||
let inp_numer_at_r2i1 = inp_numer_at_r1i1.double() - inp_numer_at_r0i1;
|
||||
let inp_denom_at_r2i1 = inp_denom_at_r1i1.double() - inp_denom_at_r0i1;
|
||||
|
||||
// Fraction addition polynomials:
|
||||
// - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)`
|
||||
// - `denom(x) = inp_denom(x, 1) * inp_denom(x, 0)`
|
||||
// at points `(r, {0, 2}, bits(i))`.
|
||||
let Fraction {
|
||||
numerator: numer_at_r0i,
|
||||
denominator: denom_at_r0i,
|
||||
} = Fraction::new(inp_numer_at_r0i0, inp_denom_at_r0i0)
|
||||
+ Fraction::new(inp_numer_at_r0i1, inp_denom_at_r0i1);
|
||||
let Fraction {
|
||||
numerator: numer_at_r2i,
|
||||
denominator: denom_at_r2i,
|
||||
} = Fraction::new(inp_numer_at_r2i0, inp_denom_at_r2i0)
|
||||
+ Fraction::new(inp_numer_at_r2i1, inp_denom_at_r2i1);
|
||||
|
||||
let eq_eval_at_0i = eq_evals[i];
|
||||
eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i);
|
||||
eval_at_2 += eq_eval_at_0i * (numer_at_r2i + lambda * denom_at_r2i);
|
||||
}
|
||||
|
||||
(eval_at_0, eval_at_2)
|
||||
}
|
||||
|
||||
/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_denom(r, t, x, 1) + inp_denom(r, t, x, 0) +
|
||||
/// lambda * inp_denom(r, t, x, 0) * inp_denom(r, t, x, 1))` at `t=0` and `t=2`.
|
||||
///
|
||||
/// Output of the form: `(eval_at_0, eval_at_2)`.
|
||||
fn eval_logup_singles_sum(
|
||||
eq_evals: &EqEvals<CpuBackend>,
|
||||
input_denominators: &Mle<CpuBackend, SecureField>,
|
||||
n_terms: usize,
|
||||
lambda: SecureField,
|
||||
) -> (SecureField, SecureField) {
|
||||
let mut eval_at_0 = SecureField::zero();
|
||||
let mut eval_at_2 = SecureField::zero();
|
||||
|
||||
for i in 0..n_terms {
|
||||
// Input polynomial at points `(r, {0, 1, 2}, bits(i), {0, 1})`.
|
||||
let inp_denom_at_r0i0 = input_denominators[i * 2];
|
||||
let inp_denom_at_r0i1 = input_denominators[i * 2 + 1];
|
||||
let inp_denom_at_r1i0 = input_denominators[(n_terms + i) * 2];
|
||||
let inp_denom_at_r1i1 = input_denominators[(n_terms + i) * 2 + 1];
|
||||
// Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)`
|
||||
// => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)`
|
||||
let inp_denom_at_r2i0 = inp_denom_at_r1i0.double() - inp_denom_at_r0i0;
|
||||
let inp_denom_at_r2i1 = inp_denom_at_r1i1.double() - inp_denom_at_r0i1;
|
||||
|
||||
// Fraction addition polynomials:
|
||||
// - `numer(x) = inp_denom(x, 1) + inp_denom(x, 0)`
|
||||
// - `denom(x) = inp_denom(x, 1) * inp_denom(x, 0)`
|
||||
// at points `(r, {0, 2}, bits(i))`.
|
||||
let Fraction {
|
||||
numerator: numer_at_r0i,
|
||||
denominator: denom_at_r0i,
|
||||
} = Reciprocal::new(inp_denom_at_r0i0) + Reciprocal::new(inp_denom_at_r0i1);
|
||||
let Fraction {
|
||||
numerator: numer_at_r2i,
|
||||
denominator: denom_at_r2i,
|
||||
} = Reciprocal::new(inp_denom_at_r2i0) + Reciprocal::new(inp_denom_at_r2i1);
|
||||
|
||||
let eq_eval_at_0i = eq_evals[i];
|
||||
eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i);
|
||||
eval_at_2 += eq_eval_at_0i * (numer_at_r2i + lambda * denom_at_r2i);
|
||||
}
|
||||
|
||||
(eval_at_0, eval_at_2)
|
||||
}
|
||||
|
||||
/// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`.
|
||||
///
|
||||
/// Evaluations are returned in bit-reversed order.
|
||||
pub fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Vec<SecureField> {
|
||||
let mut evals = Vec::with_capacity(1 << y.len());
|
||||
evals.push(v);
|
||||
|
||||
for &y_i in y.iter().rev() {
|
||||
for j in 0..evals.len() {
|
||||
// `lhs[j] = eq(0, y_i) * c[i]`
|
||||
// `rhs[j] = eq(1, y_i) * c[i]`
|
||||
let tmp = evals[j] * y_i;
|
||||
evals.push(tmp);
|
||||
evals[j] -= tmp;
|
||||
}
|
||||
}
|
||||
|
||||
evals
|
||||
}
|
||||
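A small illustration, not from the source: for a single variable `y = [y0]`, the result is `[eq(0, y0) * v, eq(1, y0) * v]`.

let y0: SecureField = BaseField::from(7).into();
let v: SecureField = BaseField::from(2).into();
// eq(0, y0) = 1 - y0 and eq(1, y0) = y0.
assert_eq!(gen_eq_evals(&[y0], v), vec![v - v * y0, v * y0]);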
|
||||
fn next_grand_product_layer(layer: &Mle<CpuBackend, SecureField>) -> Layer<CpuBackend> {
|
||||
let res = layer.array_chunks().map(|&[a, b]| a * b).collect();
|
||||
Layer::GrandProduct(Mle::new(res))
|
||||
}
|
||||
|
||||
fn next_logup_layer<F>(
|
||||
numerators: MleExpr<'_, F>,
|
||||
denominators: &Mle<CpuBackend, SecureField>,
|
||||
) -> Layer<CpuBackend>
|
||||
where
|
||||
F: Field,
|
||||
SecureField: ExtensionOf<F>,
|
||||
CpuBackend: MleOps<F>,
|
||||
{
|
||||
let half_n = 1 << (denominators.n_variables() - 1);
|
||||
let mut next_numerators = Vec::with_capacity(half_n);
|
||||
let mut next_denominators = Vec::with_capacity(half_n);
|
||||
|
||||
for i in 0..half_n {
|
||||
let a = Fraction::new(numerators[i * 2], denominators[i * 2]);
|
||||
let b = Fraction::new(numerators[i * 2 + 1], denominators[i * 2 + 1]);
|
||||
let res = a + b;
|
||||
next_numerators.push(res.numerator);
|
||||
next_denominators.push(res.denominator);
|
||||
}
|
||||
|
||||
Layer::LogUpGeneric {
|
||||
numerators: Mle::new(next_numerators),
|
||||
denominators: Mle::new(next_denominators),
|
||||
}
|
||||
}
|
||||
|
||||
enum MleExpr<'a, F: Field> {
|
||||
Constant(F),
|
||||
Mle(&'a Mle<CpuBackend, F>),
|
||||
}
|
||||
|
||||
impl<'a, F: Field> Index<usize> for MleExpr<'a, F> {
|
||||
type Output = F;
|
||||
|
||||
fn index(&self, index: usize) -> &F {
|
||||
match self {
|
||||
Self::Constant(v) => v,
|
||||
Self::Mle(mle) => &mle[index],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::iter::zip;
|
||||
|
||||
use num_traits::{One, Zero};
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use crate::core::backend::CpuBackend;
|
||||
use crate::core::channel::Channel;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer};
|
||||
use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError};
|
||||
use crate::core::lookups::mle::Mle;
|
||||
use crate::core::lookups::utils::{eq, Fraction};
|
||||
use crate::core::test_utils::test_channel;
|
||||
|
||||
#[test]
|
||||
fn gen_eq_evals() {
|
||||
let zero = SecureField::zero();
|
||||
let one = SecureField::one();
|
||||
let two = BaseField::from(2).into();
|
||||
let y = [7, 3].map(|v| BaseField::from(v).into());
|
||||
|
||||
let eq_evals = CpuBackend::gen_eq_evals(&y, two);
|
||||
|
||||
assert_eq!(
|
||||
*eq_evals,
|
||||
[
|
||||
eq(&[zero, zero], &y) * two,
|
||||
eq(&[zero, one], &y) * two,
|
||||
eq(&[one, zero], &y) * two,
|
||||
eq(&[one, one], &y) * two,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn grand_product_works() -> Result<(), GkrError> {
|
||||
const N: usize = 1 << 5;
|
||||
let values = test_channel().draw_felts(N);
|
||||
let product = values.iter().product::<SecureField>();
|
||||
let col = Mle::<CpuBackend, SecureField>::new(values);
|
||||
let input_layer = Layer::GrandProduct(col.clone());
|
||||
let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);
|
||||
|
||||
let GkrArtifact {
|
||||
ood_point: r,
|
||||
claims_to_verify_by_instance,
|
||||
n_variables_by_instance: _,
|
||||
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;
|
||||
|
||||
assert_eq!(proof.output_claims_by_instance, [vec![product]]);
|
||||
assert_eq!(claims_to_verify_by_instance, [vec![col.eval_at_point(&r)]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn logup_with_generic_trace_works() -> Result<(), GkrError> {
|
||||
const N: usize = 1 << 5;
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let numerator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
|
||||
let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
|
||||
let sum = zip(&numerator_values, &denominator_values)
|
||||
.map(|(&n, &d)| Fraction::new(n, d))
|
||||
.sum::<Fraction<SecureField, SecureField>>();
|
||||
let numerators = Mle::<CpuBackend, SecureField>::new(numerator_values);
|
||||
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
|
||||
let top_layer = Layer::LogUpGeneric {
|
||||
numerators: numerators.clone(),
|
||||
denominators: denominators.clone(),
|
||||
};
|
||||
let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer]);
|
||||
|
||||
let GkrArtifact {
|
||||
ood_point,
|
||||
claims_to_verify_by_instance,
|
||||
n_variables_by_instance: _,
|
||||
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;
|
||||
|
||||
assert_eq!(claims_to_verify_by_instance.len(), 1);
|
||||
assert_eq!(proof.output_claims_by_instance.len(), 1);
|
||||
assert_eq!(
|
||||
claims_to_verify_by_instance[0],
|
||||
[
|
||||
numerators.eval_at_point(&ood_point),
|
||||
denominators.eval_at_point(&ood_point)
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
proof.output_claims_by_instance[0],
|
||||
[sum.numerator, sum.denominator]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn logup_with_singles_trace_works() -> Result<(), GkrError> {
|
||||
const N: usize = 1 << 5;
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
|
||||
let sum = denominator_values
|
||||
.iter()
|
||||
.map(|&d| Fraction::new(SecureField::one(), d))
|
||||
.sum::<Fraction<SecureField, SecureField>>();
|
||||
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
|
||||
let top_layer = Layer::LogUpSingles {
|
||||
denominators: denominators.clone(),
|
||||
};
|
||||
let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer]);
|
||||
|
||||
let GkrArtifact {
|
||||
ood_point,
|
||||
claims_to_verify_by_instance,
|
||||
n_variables_by_instance: _,
|
||||
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;
|
||||
|
||||
assert_eq!(claims_to_verify_by_instance.len(), 1);
|
||||
assert_eq!(proof.output_claims_by_instance.len(), 1);
|
||||
assert_eq!(
|
||||
claims_to_verify_by_instance[0],
|
||||
[SecureField::one(), denominators.eval_at_point(&ood_point)]
|
||||
);
|
||||
assert_eq!(
|
||||
proof.output_claims_by_instance[0],
|
||||
[sum.numerator, sum.denominator]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn logup_with_multiplicities_trace_works() -> Result<(), GkrError> {
|
||||
const N: usize = 1 << 5;
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let numerator_values = (0..N).map(|_| rng.gen()).collect::<Vec<BaseField>>();
|
||||
let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
|
||||
let sum = zip(&numerator_values, &denominator_values)
|
||||
.map(|(&n, &d)| Fraction::new(n.into(), d))
|
||||
.sum::<Fraction<SecureField, SecureField>>();
|
||||
let numerators = Mle::<CpuBackend, BaseField>::new(numerator_values);
|
||||
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
|
||||
let top_layer = Layer::LogUpMultiplicities {
|
||||
numerators: numerators.clone(),
|
||||
denominators: denominators.clone(),
|
||||
};
|
||||
let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer]);
|
||||
|
||||
let GkrArtifact {
|
||||
ood_point,
|
||||
claims_to_verify_by_instance,
|
||||
n_variables_by_instance: _,
|
||||
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;
|
||||
|
||||
assert_eq!(claims_to_verify_by_instance.len(), 1);
|
||||
assert_eq!(proof.output_claims_by_instance.len(), 1);
|
||||
assert_eq!(
|
||||
claims_to_verify_by_instance[0],
|
||||
[
|
||||
numerators.eval_at_point(&ood_point),
|
||||
denominators.eval_at_point(&ood_point)
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
proof.output_claims_by_instance[0],
|
||||
[sum.numerator, sum.denominator]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,66 @@
use std::iter::zip;

use num_traits::{One, Zero};

use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::mle::{Mle, MleOps};
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::{fold_mle_evals, UnivariatePoly};

impl MleOps<BaseField> for CpuBackend {
    fn fix_first_variable(
        mle: Mle<Self, BaseField>,
        assignment: SecureField,
    ) -> Mle<Self, SecureField> {
        let midpoint = mle.len() / 2;
        let (lhs_evals, rhs_evals) = mle.split_at(midpoint);

        let res = zip(lhs_evals, rhs_evals)
            .map(|(&lhs_eval, &rhs_eval)| fold_mle_evals(assignment, lhs_eval, rhs_eval))
            .collect();

        Mle::new(res)
    }
}

impl MleOps<SecureField> for CpuBackend {
    fn fix_first_variable(
        mle: Mle<Self, SecureField>,
        assignment: SecureField,
    ) -> Mle<Self, SecureField> {
        let midpoint = mle.len() / 2;
        let mut evals = mle.into_evals();

        for i in 0..midpoint {
            let lhs_eval = evals[i];
            let rhs_eval = evals[i + midpoint];
            evals[i] = fold_mle_evals(assignment, lhs_eval, rhs_eval);
        }

        evals.truncate(midpoint);

        Mle::new(evals)
    }
}

impl MultivariatePolyOracle for Mle<CpuBackend, SecureField> {
    fn n_variables(&self) -> usize {
        self.n_variables()
    }

    fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly<SecureField> {
        let x0 = SecureField::zero();
        let x1 = SecureField::one();

        let y0 = self[0..self.len() / 2].iter().sum();
        let y1 = claim - y0;

        UnivariatePoly::interpolate_lagrange(&[x0, x1], &[y0, y1])
    }

    fn fix_first_variable(self, challenge: SecureField) -> Self {
        self.fix_first_variable(challenge)
    }
}
@ -0,0 +1,2 @@
pub mod gkr;
mod mle;
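Both `fix_first_variable` implementations above fold pairs of evaluations with `fold_mle_evals`. A minimal sketch (assumed semantics, not taken from this diff, with f64 standing in for the field) of the identity that helper is expected to compute: fixing the first variable of a multilinear extension to `assignment` linearly interpolates each (lhs, rhs) pair.

// Sketch: fixing x_0 = t in a multilinear extension f gives
// f(t, rest) = (1 - t) * f(0, rest) + t * f(1, rest) = lhs + t * (rhs - lhs).
fn fold_mle_evals_sketch(assignment: f64, lhs: f64, rhs: f64) -> f64 {
    lhs + assignment * (rhs - lhs)
}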
105
Stwo_wrapper/crates/prover/src/core/backend/cpu/mod.rs
Normal file
@ -0,0 +1,105 @@
mod accumulation;
mod blake2s;
mod circle;
mod fri;
mod grind;
pub mod lookups;
#[cfg(not(target_arch = "wasm32"))]
mod poseidon252;
pub mod quotients;
#[cfg(not(target_arch = "wasm32"))]
mod poseidon_bls;

use std::fmt::Debug;

use serde::{Deserialize, Serialize};

use super::{Backend, BackendForChannel, Column, ColumnOps, FieldOps};
use crate::core::fields::Field;
use crate::core::lookups::mle::Mle;
use crate::core::poly::circle::{CircleEvaluation, CirclePoly};
use crate::core::utils::bit_reverse;
use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel;
#[cfg(not(target_arch = "wasm32"))]
use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel;

#[cfg(not(target_arch = "wasm32"))]
use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleChannel;

#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
pub struct CpuBackend;

impl Backend for CpuBackend {}
impl BackendForChannel<Blake2sMerkleChannel> for CpuBackend {}
#[cfg(not(target_arch = "wasm32"))]
impl BackendForChannel<Poseidon252MerkleChannel> for CpuBackend {}

#[cfg(not(target_arch = "wasm32"))]
impl BackendForChannel<PoseidonBLSMerkleChannel> for CpuBackend {}

impl<T: Debug + Clone + Default> ColumnOps<T> for CpuBackend {
    type Column = Vec<T>;

    fn bit_reverse_column(column: &mut Self::Column) {
        bit_reverse(column)
    }
}

impl<F: Field> FieldOps<F> for CpuBackend {
    /// Batch inversion using Montgomery's trick.
    // TODO(Ohad): Benchmark this function.
    fn batch_inverse(column: &Self::Column, dst: &mut Self::Column) {
        F::batch_inverse(column, &mut dst[..]);
    }
}

impl<T: Debug + Clone + Default> Column<T> for Vec<T> {
    fn zeros(len: usize) -> Self {
        vec![T::default(); len]
    }
    #[allow(clippy::uninit_vec)]
    unsafe fn uninitialized(length: usize) -> Self {
        let mut data = Vec::with_capacity(length);
        data.set_len(length);
        data
    }
    fn to_cpu(&self) -> Vec<T> {
        self.clone()
    }
    fn len(&self) -> usize {
        self.len()
    }
    fn at(&self, index: usize) -> T {
        self[index].clone()
    }
    fn set(&mut self, index: usize, value: T) {
        self[index] = value;
    }
}

pub type CpuCirclePoly = CirclePoly<CpuBackend>;
pub type CpuCircleEvaluation<F, EvalOrder> = CircleEvaluation<CpuBackend, F, EvalOrder>;
pub type CpuMle<F> = Mle<CpuBackend, F>;

#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::prelude::*;
    use rand::rngs::SmallRng;

    use crate::core::backend::{Column, CpuBackend, FieldOps};
    use crate::core::fields::qm31::QM31;
    use crate::core::fields::FieldExpOps;

    #[test]
    fn batch_inverse_test() {
        let mut rng = SmallRng::seed_from_u64(0);
        let column = rng.gen::<[QM31; 16]>().to_vec();
        let expected = column.iter().map(|e| e.inverse()).collect_vec();
        let mut dst = Column::zeros(column.len());

        CpuBackend::batch_inverse(&column, &mut dst);

        assert_eq!(expected, dst);
    }
}
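`batch_inverse` above delegates to `F::batch_inverse`, which the doc comment attributes to Montgomery's trick. A standalone sketch of that trick (illustrative only, with f64 standing in for the field and a hypothetical function name): n inversions are traded for a single inversion plus O(n) multiplications.

// Sketch of Montgomery batch inversion: one division for the whole column.
fn batch_inverse_sketch(column: &[f64], dst: &mut [f64]) {
    assert_eq!(column.len(), dst.len());
    // Phase 1: dst[i] holds the prefix product column[0] * ... * column[i].
    let mut acc = 1.0;
    for (d, &c) in dst.iter_mut().zip(column) {
        acc *= c;
        *d = acc;
    }
    // Phase 2: invert the total product once, then peel it back element by element.
    let mut inv = 1.0 / acc;
    for i in (0..column.len()).rev() {
        let prefix = if i == 0 { 1.0 } else { dst[i - 1] };
        dst[i] = inv * prefix; // equals 1 / column[i]
        inv *= column[i]; // now inv = 1 / (column[0] * ... * column[i - 1])
    }
}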
@ -0,0 +1,24 @@
use itertools::Itertools;
use starknet_ff::FieldElement as FieldElement252;

use super::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};
use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleHasher;

impl MerkleOps<Poseidon252MerkleHasher> for CpuBackend {
    fn commit_on_layer(
        log_size: u32,
        prev_layer: Option<&Vec<FieldElement252>>,
        columns: &[&Vec<BaseField>],
    ) -> Vec<FieldElement252> {
        (0..(1 << log_size))
            .map(|i| {
                Poseidon252MerkleHasher::hash_node(
                    prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
                    &columns.iter().map(|column| column[i]).collect_vec(),
                )
            })
            .collect()
    }
}
@ -0,0 +1,24 @@
use itertools::Itertools;
use ark_bls12_381::Fr as BlsFr;

use super::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};
use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleHasher;

impl MerkleOps<PoseidonBLSMerkleHasher> for CpuBackend {
    fn commit_on_layer(
        log_size: u32,
        prev_layer: Option<&Vec<BlsFr>>,
        columns: &[&Vec<BaseField>],
    ) -> Vec<BlsFr> {
        (0..(1 << log_size))
            .map(|i| {
                PoseidonBLSMerkleHasher::hash_node(
                    prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
                    &columns.iter().map(|column| column[i]).collect_vec(),
                )
            })
            .collect()
    }
}
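Both Poseidon backends above build one Merkle layer at a time: layer `log_size` has 2^log_size nodes, each hashing its two children from `prev_layer` plus any column values injected at that layer. A hedged sketch (all names and the toy hash are illustrative, not from this diff) of how a caller could chain such calls into a full tree by feeding each layer back in as `prev_layer`:

// Sketch only: toy stand-ins for hash_node / commit_on_layer, to show how layers chain.
fn hash_node_sketch(children: Option<(u64, u64)>, column_values: &[u32]) -> u64 {
    // Toy FNV-style mixing, purely for illustration; the real code uses Poseidon.
    let mut h: u64 = 0xcbf2_9ce4_8422_2325;
    let mut inputs: Vec<u64> = Vec::new();
    if let Some((l, r)) = children {
        inputs.push(l);
        inputs.push(r);
    }
    inputs.extend(column_values.iter().map(|&v| v as u64));
    for x in inputs {
        h ^= x;
        h = h.wrapping_mul(0x0000_0100_0000_01b3);
    }
    h
}

fn commit_on_layer_sketch(
    log_size: u32,
    prev_layer: Option<&Vec<u64>>,
    columns: &[&Vec<u32>],
) -> Vec<u64> {
    (0..(1usize << log_size))
        .map(|i| {
            hash_node_sketch(
                prev_layer.map(|p| (p[2 * i], p[2 * i + 1])),
                &columns.iter().map(|c| c[i]).collect::<Vec<_>>(),
            )
        })
        .collect()
}

fn commit_tree_sketch(leaves: &Vec<u32>, leaf_log_size: u32) -> Vec<u64> {
    assert_eq!(leaves.len(), 1 << leaf_log_size);
    // Leaf layer first, then halve the layer size until the single root node.
    let mut layer = commit_on_layer_sketch(leaf_log_size, None, &[leaves]);
    for log_size in (0..leaf_log_size).rev() {
        layer = commit_on_layer_sketch(log_size, Some(&layer), &[]);
    }
    layer
}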
210
Stwo_wrapper/crates/prover/src/core/backend/cpu/quotients.rs
Normal file
@ -0,0 +1,210 @@
use itertools::{izip, zip_eq};
use num_traits::{One, Zero};

use super::CpuBackend;
use crate::core::circle::CirclePoint;
use crate::core::constraints::complex_conjugate_line_coeffs;
use crate::core::fields::cm31::CM31;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fields::FieldExpOps;
use crate::core::pcs::quotients::{ColumnSampleBatch, PointSample, QuotientOps};
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse, bit_reverse_index};

impl QuotientOps for CpuBackend {
    fn accumulate_quotients(
        domain: CircleDomain,
        columns: &[&CircleEvaluation<Self, BaseField, BitReversedOrder>],
        random_coeff: SecureField,
        sample_batches: &[ColumnSampleBatch],
        _log_blowup_factor: u32,
    ) -> SecureEvaluation<Self, BitReversedOrder> {
        let mut values = unsafe { SecureColumnByCoords::uninitialized(domain.size()) };
        let quotient_constants = quotient_constants(sample_batches, random_coeff, domain);

        for row in 0..domain.size() {
            let domain_point = domain.at(bit_reverse_index(row, domain.log_size()));
            let row_value = accumulate_row_quotients(
                sample_batches,
                columns,
                &quotient_constants,
                row,
                domain_point,
            );
            values.set(row, row_value);
        }
        SecureEvaluation::new(domain, values)
    }
}

pub fn accumulate_row_quotients(
    sample_batches: &[ColumnSampleBatch],
    columns: &[&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>],
    quotient_constants: &QuotientConstants,
    row: usize,
    domain_point: CirclePoint<BaseField>,
) -> SecureField {
    let mut row_accumulator = SecureField::zero();
    for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!(
        sample_batches,
        &quotient_constants.line_coeffs,
        &quotient_constants.batch_random_coeffs,
        &quotient_constants.denominator_inverses
    ) {
        let mut numerator = SecureField::zero();
        for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs)
        {
            let column = &columns[*column_index];
            let value = column[row] * *c;
            // The numerator is a line equation passing through
            // (sample_point.y, sample_value), (conj(sample_point), conj(sample_value))
            // evaluated at (domain_point.y, value).
            // When substituting a polynomial in this line equation, we get a polynomial with a root
            // at sample_point and conj(sample_point) if the original polynomial had the values
            // sample_value and conj(sample_value) at these points.
            let linear_term = *a * domain_point.y + *b;
            numerator += value - linear_term;
        }

        row_accumulator =
            row_accumulator * *batch_coeff + numerator.mul_cm31(denominator_inverses[row]);
    }
    row_accumulator
}

/// Precompute the complex conjugate line coefficients for each column in each sample batch.
/// Specifically, for the i-th (in a sample batch) column's numerator term
/// `alpha^i * (c * F(p) - (a * p.y + b))`, we precompute and return the constants:
/// (`alpha^i * a`, `alpha^i * b`, `alpha^i * c`).
pub fn column_line_coeffs(
    sample_batches: &[ColumnSampleBatch],
    random_coeff: SecureField,
) -> Vec<Vec<(SecureField, SecureField, SecureField)>> {
    sample_batches
        .iter()
        .map(|sample_batch| {
            let mut alpha = SecureField::one();
            sample_batch
                .columns_and_values
                .iter()
                .map(|(_, sampled_value)| {
                    alpha *= random_coeff;
                    let sample = PointSample {
                        point: sample_batch.point,
                        value: *sampled_value,
                    };
                    complex_conjugate_line_coeffs(&sample, alpha)
                })
                .collect()
        })
        .collect()
}

/// Precompute the random coefficients used to linearly combine the batched quotients.
/// Specifically, for each sample batch we compute random_coeff^(number of columns in the batch),
/// which is used to linearly combine the batch with the next one.
pub fn batch_random_coeffs(
    sample_batches: &[ColumnSampleBatch],
    random_coeff: SecureField,
) -> Vec<SecureField> {
    sample_batches
        .iter()
        .map(|sb| random_coeff.pow(sb.columns_and_values.len() as u128))
        .collect()
}

fn denominator_inverses(
    sample_batches: &[ColumnSampleBatch],
    domain: CircleDomain,
) -> Vec<Vec<CM31>> {
    let mut flat_denominators = Vec::with_capacity(sample_batches.len() * domain.size());
    // We want P to be on a line that passes through a point Pr + uPi in QM31^2, and its conjugate
    // Pr - uPi. Thus, Pr - P is parallel to Pi. Or, (Pr - P).x * Pi.y - (Pr - P).y * Pi.x = 0.
    for sample_batch in sample_batches {
        // Extract Pr, Pi.
        let prx = sample_batch.point.x.0;
        let pry = sample_batch.point.y.0;
        let pix = sample_batch.point.x.1;
        let piy = sample_batch.point.y.1;
        for row in 0..domain.size() {
            let domain_point = domain.at(row);
            flat_denominators.push((prx - domain_point.x) * piy - (pry - domain_point.y) * pix);
        }
    }

    let mut flat_denominator_inverses = vec![CM31::zero(); flat_denominators.len()];
    CM31::batch_inverse(&flat_denominators, &mut flat_denominator_inverses);

    flat_denominator_inverses
        .chunks_mut(domain.size())
        .map(|denominator_inverses| {
            bit_reverse(denominator_inverses);
            denominator_inverses.to_vec()
        })
        .collect()
}

pub fn quotient_constants(
    sample_batches: &[ColumnSampleBatch],
    random_coeff: SecureField,
    domain: CircleDomain,
) -> QuotientConstants {
    let line_coeffs = column_line_coeffs(sample_batches, random_coeff);
    let batch_random_coeffs = batch_random_coeffs(sample_batches, random_coeff);
    let denominator_inverses = denominator_inverses(sample_batches, domain);
    QuotientConstants {
        line_coeffs,
        batch_random_coeffs,
        denominator_inverses,
    }
}

/// Holds the precomputed constant values used in each quotient evaluation.
pub struct QuotientConstants {
    /// The line coefficients for each quotient numerator term. For more details see
    /// [self::column_line_coeffs].
    pub line_coeffs: Vec<Vec<(SecureField, SecureField, SecureField)>>,
    /// The random coefficients used to linearly combine the batched quotients. For more details
    /// see [self::batch_random_coeffs].
    pub batch_random_coeffs: Vec<SecureField>,
    /// The inverses of the denominators of the quotients.
    pub denominator_inverses: Vec<Vec<CM31>>,
}

#[cfg(test)]
mod tests {
    use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly};
    use crate::core::backend::CpuBackend;
    use crate::core::circle::SECURE_FIELD_CIRCLE_GEN;
    use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps};
    use crate::core::poly::circle::CanonicCoset;
    use crate::{m31, qm31};

    #[test]
    fn test_quotients_are_low_degree() {
        const LOG_SIZE: u32 = 7;
        const LOG_BLOWUP_FACTOR: u32 = 1;
        let polynomial = CpuCirclePoly::new((0..1 << LOG_SIZE).map(|i| m31!(i)).collect());
        let eval_domain = CanonicCoset::new(LOG_SIZE + 1).circle_domain();
        let eval = polynomial.evaluate(eval_domain);
        let point = SECURE_FIELD_CIRCLE_GEN;
        let value = polynomial.eval_at_point(point);
        let coeff = qm31!(1, 2, 3, 4);
        let quot_eval = CpuBackend::accumulate_quotients(
            eval_domain,
            &[&eval],
            coeff,
            &[ColumnSampleBatch {
                point,
                columns_and_values: vec![(0, value)],
            }],
            LOG_BLOWUP_FACTOR,
        );
        let quot_poly_base_field =
            CpuCircleEvaluation::new(eval_domain, quot_eval.columns[0].clone()).interpolate();
        assert!(quot_poly_base_field.is_in_fri_space(LOG_SIZE));
    }
}
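A small numeric sketch (plain f64 stand-ins, not part of the diff) of the Horner-style batching performed per row in `accumulate_row_quotients` above: each batch's quotient q_j is folded into a running accumulator as acc = acc * coeff_j + q_j, so the final value weights every earlier quotient by the product of the batch coefficients of all later batches, without storing the individual quotients.

// Sketch: folding per-batch quotients with per-batch coefficients, in the same
// row_accumulator = row_accumulator * batch_coeff + quotient shape used above.
fn accumulate_row_sketch(quotients: &[f64], batch_coeffs: &[f64]) -> f64 {
    let mut acc = 0.0;
    for (&q, &coeff) in quotients.iter().zip(batch_coeffs) {
        acc = acc * coeff + q;
    }
    acc
}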
66
Stwo_wrapper/crates/prover/src/core/backend/mod.rs
Normal file
@ -0,0 +1,66 @@
use std::fmt::Debug;

pub use cpu::CpuBackend;

use super::air::accumulation::AccumulationOps;
use super::channel::MerkleChannel;
use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::fields::FieldOps;
use super::fri::FriOps;
use super::lookups::gkr_prover::GkrOps;
use super::pcs::quotients::QuotientOps;
use super::poly::circle::PolyOps;
use super::proof_of_work::GrindOps;
use super::vcs::ops::MerkleOps;

pub mod cpu;
pub mod simd;

pub trait Backend:
    Copy
    + Clone
    + Debug
    + FieldOps<BaseField>
    + FieldOps<SecureField>
    + PolyOps
    + QuotientOps
    + FriOps
    + AccumulationOps
    + GkrOps
{
}

pub trait BackendForChannel<MC: MerkleChannel>:
    Backend + MerkleOps<MC::H> + GrindOps<MC::C>
{
}

pub trait ColumnOps<T> {
    type Column: Column<T>;
    fn bit_reverse_column(column: &mut Self::Column);
}

pub type Col<B, T> = <B as ColumnOps<T>>::Column;

// TODO(spapini): Consider removing the generic parameter and only support BaseField.
pub trait Column<T>: Clone + Debug + FromIterator<T> {
    /// Creates a new column of zeros with the given length.
    fn zeros(len: usize) -> Self;
    /// Creates a new column of uninitialized values with the given length.
    /// # Safety
    /// The caller must ensure that the column is populated before being used.
    unsafe fn uninitialized(len: usize) -> Self;
    /// Returns a cpu vector of the column.
    fn to_cpu(&self) -> Vec<T>;
    /// Returns the length of the column.
    fn len(&self) -> usize;
    /// Returns true if the column is empty.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Retrieves the element at the given index.
    fn at(&self, index: usize) -> T;
    /// Sets the element at the given index.
    fn set(&mut self, index: usize, value: T);
}
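A minimal sketch (illustrative only, not part of the diff) of how code generic over the `ColumnOps`/`Column` traits above can manipulate a column through `len`/`at`/`set` without knowing whether it is the CPU `Vec<T>` implementation or a SIMD column:

// Sketch: a backend-agnostic helper written only against the Column trait surface.
fn reverse_column_sketch<B, T>(column: &mut Col<B, T>)
where
    B: ColumnOps<T>,
{
    let n = column.len();
    for i in 0..n / 2 {
        let (a, b) = (column.at(i), column.at(n - 1 - i));
        column.set(i, b);
        column.set(n - 1 - i, a);
    }
}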
@ -0,0 +1,12 @@
use super::SimdBackend;
use crate::core::air::accumulation::AccumulationOps;
use crate::core::fields::secure_column::SecureColumnByCoords;

impl AccumulationOps for SimdBackend {
    fn accumulate(column: &mut SecureColumnByCoords<Self>, other: &SecureColumnByCoords<Self>) {
        for i in 0..column.packed_len() {
            let res_coeff = unsafe { column.packed_at(i) + other.packed_at(i) };
            unsafe { column.set_packed(i, res_coeff) };
        }
    }
}
203
Stwo_wrapper/crates/prover/src/core/backend/simd/bit_reverse.rs
Normal file
@ -0,0 +1,203 @@
|
||||
use std::array;
|
||||
|
||||
use super::column::{BaseColumn, SecureColumn};
|
||||
use super::m31::PackedBaseField;
|
||||
use super::SimdBackend;
|
||||
use crate::core::backend::ColumnOps;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::utils::{bit_reverse as cpu_bit_reverse, bit_reverse_index};
|
||||
|
||||
const VEC_BITS: u32 = 4;
|
||||
|
||||
const W_BITS: u32 = 3;
|
||||
|
||||
pub const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS;
|
||||
|
||||
impl ColumnOps<BaseField> for SimdBackend {
|
||||
type Column = BaseColumn;
|
||||
|
||||
fn bit_reverse_column(column: &mut Self::Column) {
|
||||
// Fallback to cpu bit_reverse.
|
||||
if column.data.len().ilog2() < MIN_LOG_SIZE {
|
||||
cpu_bit_reverse(column.as_mut_slice());
|
||||
return;
|
||||
}
|
||||
|
||||
bit_reverse_m31(&mut column.data);
|
||||
}
|
||||
}
|
||||
|
||||
impl ColumnOps<SecureField> for SimdBackend {
|
||||
type Column = SecureColumn;
|
||||
|
||||
fn bit_reverse_column(_column: &mut SecureColumn) {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
/// Bit reverses M31 values.
|
||||
///
|
||||
/// Given an array `A[0..2^n)`, computes `B[i] = A[bit_reverse(i)]`.
|
||||
pub fn bit_reverse_m31(data: &mut [PackedBaseField]) {
|
||||
assert!(data.len().is_power_of_two());
|
||||
assert!(data.len().ilog2() >= MIN_LOG_SIZE);
|
||||
|
||||
// Indices in the array are of the form v_h w_h a w_l v_l, with
|
||||
// |v_h| = |v_l| = VEC_BITS, |w_h| = |w_l| = W_BITS, |a| = n - 2*W_BITS - VEC_BITS.
|
||||
// The loops go over a, w_l, w_h, and then swap the 16 by 16 values at:
|
||||
// * w_h a w_l * <-> * rev(w_h a w_l) *.
|
||||
// These are 1 or 2 chunks of 2^W_BITS contiguous `u32x16` vectors.
|
||||
|
||||
let log_size = data.len().ilog2();
|
||||
let a_bits = log_size - 2 * W_BITS - VEC_BITS;
|
||||
|
||||
// TODO(spapini): when doing multithreading, do it over a.
|
||||
for a in 0u32..1 << a_bits {
|
||||
for w_l in 0u32..1 << W_BITS {
|
||||
let w_l_rev = w_l.reverse_bits() >> (u32::BITS - W_BITS);
|
||||
for w_h in 0..w_l_rev + 1 {
|
||||
let idx = ((((w_h << a_bits) | a) << W_BITS) | w_l) as usize;
|
||||
let idx_rev = bit_reverse_index(idx, log_size - VEC_BITS);
|
||||
|
||||
// In order to not swap twice, only swap if idx <= idx_rev.
|
||||
if idx > idx_rev {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Read first chunk.
|
||||
// TODO(spapini): Think about optimizing a_bits.
|
||||
let chunk0 = array::from_fn(|i| unsafe {
|
||||
*data.get_unchecked(idx + (i << (2 * W_BITS + a_bits)))
|
||||
});
|
||||
let values0 = bit_reverse16(chunk0);
|
||||
|
||||
if idx == idx_rev {
|
||||
// Palindrome index. Write into the same chunk.
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for i in 0..16 {
|
||||
unsafe {
|
||||
*data.get_unchecked_mut(idx + (i << (2 * W_BITS + a_bits))) =
|
||||
values0[i];
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Read bit reversed chunk.
|
||||
let chunk1 = array::from_fn(|i| unsafe {
|
||||
*data.get_unchecked(idx_rev + (i << (2 * W_BITS + a_bits)))
|
||||
});
|
||||
let values1 = bit_reverse16(chunk1);
|
||||
|
||||
for i in 0..16 {
|
||||
unsafe {
|
||||
*data.get_unchecked_mut(idx + (i << (2 * W_BITS + a_bits))) = values1[i];
|
||||
*data.get_unchecked_mut(idx_rev + (i << (2 * W_BITS + a_bits))) =
|
||||
values0[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
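
// For reference, a minimal scalar sketch (not from this diff) of the permutation the SIMD
// routine above must reproduce, i.e. B[i] = A[bit_reverse(i)] over log_n-bit indices,
// matching the behaviour assumed of the cpu_bit_reverse fallback.
fn bit_reverse_sketch<T>(data: &mut [T]) {
    let n = data.len();
    assert!(n.is_power_of_two());
    if n <= 1 {
        return;
    }
    let log_n = n.ilog2();
    for i in 0..n {
        // j is i with its lowest log_n bits reversed.
        let j = i.reverse_bits() >> (usize::BITS - log_n);
        if i < j {
            data.swap(i, j); // swap each pair exactly once
        }
    }
}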
|
||||
|
||||
/// Bit reverses 256 M31 values, packed in 16 words of 16 elements each.
|
||||
fn bit_reverse16(mut data: [PackedBaseField; 16]) -> [PackedBaseField; 16] {
|
||||
// Denote the index of each element in the 16 packed M31 words as abcd:0123,
|
||||
// where abcd is the index of the packed word and 0123 is the index of the element in the word.
|
||||
// Bit reversal is achieved by applying the following permutation to the index for 4 times:
|
||||
// abcd:0123 => 0abc:123d
|
||||
// This is how it looks like at each iteration.
|
||||
// abcd:0123
|
||||
// 0abc:123d
|
||||
// 10ab:23dc
|
||||
// 210a:3dcb
|
||||
// 3210:dcba
|
||||
for _ in 0..4 {
|
||||
// Apply the abcd:0123 => 0abc:123d permutation.
|
||||
// `interleave` allows us to interleave the first half of 2 words.
|
||||
// For example, the second call interleaves 0010:0xyz (low half of register 2) with
|
||||
// 0011:0xyz (low half of register 3), and stores the result in register 1 (0001).
|
||||
// This results in
|
||||
// 0001:xyz0 (even indices of register 1) <= 0010:0xyz (low half of register2), and
|
||||
// 0001:xyz1 (odd indices of register 1) <= 0011:0xyz (low half of register 3)
|
||||
// or 0001:xyzw <= 001w:0xyz.
|
||||
let (d0, d8) = data[0].interleave(data[1]);
|
||||
let (d1, d9) = data[2].interleave(data[3]);
|
||||
let (d2, d10) = data[4].interleave(data[5]);
|
||||
let (d3, d11) = data[6].interleave(data[7]);
|
||||
let (d4, d12) = data[8].interleave(data[9]);
|
||||
let (d5, d13) = data[10].interleave(data[11]);
|
||||
let (d6, d14) = data[12].interleave(data[13]);
|
||||
let (d7, d15) = data[14].interleave(data[15]);
|
||||
data = [
|
||||
d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15,
|
||||
];
|
||||
}
|
||||
|
||||
data
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use itertools::Itertools;
|
||||
|
||||
use super::{bit_reverse16, bit_reverse_m31, MIN_LOG_SIZE};
|
||||
use crate::core::backend::simd::column::BaseColumn;
|
||||
use crate::core::backend::simd::m31::{PackedM31, N_LANES};
|
||||
use crate::core::backend::simd::SimdBackend;
|
||||
use crate::core::backend::{Column, ColumnOps};
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::utils::bit_reverse as cpu_bit_reverse;
|
||||
|
||||
#[test]
|
||||
fn test_bit_reverse16() {
|
||||
let values: BaseColumn = (0..N_LANES * 16).map(BaseField::from).collect();
|
||||
let mut expected = values.to_cpu();
|
||||
cpu_bit_reverse(&mut expected);
|
||||
|
||||
let res = bit_reverse16(values.data.try_into().unwrap());
|
||||
|
||||
assert_eq!(res.map(PackedM31::to_array).flatten(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bit_reverse_m31_works() {
|
||||
const SIZE: usize = 1 << 15;
|
||||
let data: Vec<_> = (0..SIZE).map(BaseField::from).collect();
|
||||
let mut expected = data.clone();
|
||||
cpu_bit_reverse(&mut expected);
|
||||
|
||||
let mut res: BaseColumn = data.into_iter().collect();
|
||||
bit_reverse_m31(&mut res.data[..]);
|
||||
|
||||
assert_eq!(res.to_cpu(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bit_reverse_small_column_works() {
|
||||
const LOG_SIZE: u32 = MIN_LOG_SIZE - 1;
|
||||
let column = (0..1 << LOG_SIZE).map(BaseField::from).collect_vec();
|
||||
let mut expected = column.clone();
|
||||
cpu_bit_reverse(&mut expected);
|
||||
|
||||
let mut res = column.iter().copied().collect::<BaseColumn>();
|
||||
<SimdBackend as ColumnOps<BaseField>>::bit_reverse_column(&mut res);
|
||||
|
||||
assert_eq!(res.to_cpu(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bit_reverse_large_column_works() {
|
||||
const LOG_SIZE: u32 = MIN_LOG_SIZE;
|
||||
let column = (0..1 << LOG_SIZE).map(BaseField::from).collect_vec();
|
||||
let mut expected = column.clone();
|
||||
cpu_bit_reverse(&mut expected);
|
||||
|
||||
let mut res = column.iter().copied().collect::<BaseColumn>();
|
||||
<SimdBackend as ColumnOps<BaseField>>::bit_reverse_column(&mut res);
|
||||
|
||||
assert_eq!(res.to_cpu(), expected);
|
||||
}
|
||||
}
|
||||
412
Stwo_wrapper/crates/prover/src/core/backend/simd/blake2s.rs
Normal file
@ -0,0 +1,412 @@
|
||||
//! A SIMD implementation of the BLAKE2s compression function.
|
||||
//! Based on <https://github.com/oconnor663/blake2_simd/blob/master/blake2s/src/avx2.rs>.
|
||||
|
||||
use std::array;
|
||||
use std::iter::repeat;
|
||||
use std::mem::transmute;
|
||||
use std::simd::u32x16;
|
||||
|
||||
use bytemuck::cast_slice;
|
||||
use itertools::Itertools;
|
||||
#[cfg(feature = "parallel")]
|
||||
use rayon::prelude::*;
|
||||
|
||||
use super::m31::{LOG_N_LANES, N_LANES};
|
||||
use super::SimdBackend;
|
||||
use crate::core::backend::{Col, Column, ColumnOps};
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::vcs::blake2_hash::Blake2sHash;
|
||||
use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
|
||||
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};
|
||||
|
||||
const IV: [u32; 8] = [
|
||||
0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
|
||||
];
|
||||
|
||||
pub const SIGMA: [[u8; 16]; 10] = [
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
|
||||
[14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3],
|
||||
[11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4],
|
||||
[7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8],
|
||||
[9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13],
|
||||
[2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9],
|
||||
[12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11],
|
||||
[13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10],
|
||||
[6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5],
|
||||
[10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0],
|
||||
];
|
||||
|
||||
impl ColumnOps<Blake2sHash> for SimdBackend {
|
||||
type Column = Vec<Blake2sHash>;
|
||||
|
||||
fn bit_reverse_column(_column: &mut Self::Column) {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl MerkleOps<Blake2sMerkleHasher> for SimdBackend {
|
||||
fn commit_on_layer(
|
||||
log_size: u32,
|
||||
prev_layer: Option<&Vec<Blake2sHash>>,
|
||||
columns: &[&Col<Self, BaseField>],
|
||||
) -> Vec<Blake2sHash> {
|
||||
if log_size < LOG_N_LANES {
|
||||
#[cfg(not(feature = "parallel"))]
|
||||
let iter = 0..1 << log_size;
|
||||
|
||||
#[cfg(feature = "parallel")]
|
||||
let iter = (0..1 << log_size).into_par_iter();
|
||||
|
||||
return iter
|
||||
.map(|i| {
|
||||
Blake2sMerkleHasher::hash_node(
|
||||
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
|
||||
&columns.iter().map(|column| column.at(i)).collect_vec(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
if let Some(prev_layer) = prev_layer {
|
||||
assert_eq!(prev_layer.len(), 1 << (log_size + 1));
|
||||
}
|
||||
|
||||
let zeros = u32x16::splat(0);
|
||||
|
||||
// Commit to columns.
|
||||
let mut res = vec![Blake2sHash::default(); 1 << log_size];
|
||||
#[cfg(not(feature = "parallel"))]
|
||||
let iter = res.chunks_mut(1 << LOG_N_LANES);
|
||||
|
||||
#[cfg(feature = "parallel")]
|
||||
let iter = res.par_chunks_mut(1 << LOG_N_LANES);
|
||||
|
||||
iter.enumerate().for_each(|(i, chunk)| {
|
||||
let mut state: [u32x16; 8] = unsafe { std::mem::zeroed() };
|
||||
// Hash prev_layer, if exists.
|
||||
if let Some(prev_layer) = prev_layer {
|
||||
let prev_chunk_u32s = cast_slice::<_, u32>(&prev_layer[(i << 5)..((i + 1) << 5)]);
|
||||
// Note: prev_layer might be unaligned.
|
||||
let msgs: [u32x16; 16] = array::from_fn(|j| {
|
||||
u32x16::from_array(std::array::from_fn(|k| prev_chunk_u32s[16 * j + k]))
|
||||
});
|
||||
state = compress16(state, transpose_msgs(msgs), zeros, zeros, zeros, zeros);
|
||||
}
|
||||
|
||||
// Hash columns in chunks of 16.
|
||||
let mut col_chunk_iter = columns.array_chunks();
|
||||
for col_chunk in &mut col_chunk_iter {
|
||||
let msgs = col_chunk.map(|column| column.data[i].into_simd());
|
||||
state = compress16(state, msgs, zeros, zeros, zeros, zeros);
|
||||
}
|
||||
|
||||
// Hash remaining columns.
|
||||
let remainder = col_chunk_iter.remainder();
|
||||
if !remainder.is_empty() {
|
||||
let msgs = remainder
|
||||
.iter()
|
||||
.map(|column| column.data[i].into_simd())
|
||||
.chain(repeat(zeros))
|
||||
.take(N_LANES)
|
||||
.collect_vec()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
state = compress16(state, msgs, zeros, zeros, zeros, zeros);
|
||||
}
|
||||
let state: [Blake2sHash; 16] = unsafe { transmute(untranspose_states(state)) };
|
||||
chunk.copy_from_slice(&state);
|
||||
});
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies [`u32::rotate_right(N)`] to each element of the vector
|
||||
///
|
||||
/// [`u32::rotate_right(N)`]: u32::rotate_right
|
||||
#[inline(always)]
|
||||
fn rotate<const N: u32>(x: u32x16) -> u32x16 {
|
||||
(x >> N) | (x << (u32::BITS - N))
|
||||
}
|
||||
|
||||
// `inline(always)` can cause code parsing errors for wasm: "locals exceed maximum".
|
||||
#[cfg_attr(not(target_arch = "wasm32"), inline(always))]
|
||||
pub fn round(v: &mut [u32x16; 16], m: [u32x16; 16], r: usize) {
|
||||
v[0] += m[SIGMA[r][0] as usize];
|
||||
v[1] += m[SIGMA[r][2] as usize];
|
||||
v[2] += m[SIGMA[r][4] as usize];
|
||||
v[3] += m[SIGMA[r][6] as usize];
|
||||
v[0] += v[4];
|
||||
v[1] += v[5];
|
||||
v[2] += v[6];
|
||||
v[3] += v[7];
|
||||
v[12] ^= v[0];
|
||||
v[13] ^= v[1];
|
||||
v[14] ^= v[2];
|
||||
v[15] ^= v[3];
|
||||
v[12] = rotate::<16>(v[12]);
|
||||
v[13] = rotate::<16>(v[13]);
|
||||
v[14] = rotate::<16>(v[14]);
|
||||
v[15] = rotate::<16>(v[15]);
|
||||
v[8] += v[12];
|
||||
v[9] += v[13];
|
||||
v[10] += v[14];
|
||||
v[11] += v[15];
|
||||
v[4] ^= v[8];
|
||||
v[5] ^= v[9];
|
||||
v[6] ^= v[10];
|
||||
v[7] ^= v[11];
|
||||
v[4] = rotate::<12>(v[4]);
|
||||
v[5] = rotate::<12>(v[5]);
|
||||
v[6] = rotate::<12>(v[6]);
|
||||
v[7] = rotate::<12>(v[7]);
|
||||
v[0] += m[SIGMA[r][1] as usize];
|
||||
v[1] += m[SIGMA[r][3] as usize];
|
||||
v[2] += m[SIGMA[r][5] as usize];
|
||||
v[3] += m[SIGMA[r][7] as usize];
|
||||
v[0] += v[4];
|
||||
v[1] += v[5];
|
||||
v[2] += v[6];
|
||||
v[3] += v[7];
|
||||
v[12] ^= v[0];
|
||||
v[13] ^= v[1];
|
||||
v[14] ^= v[2];
|
||||
v[15] ^= v[3];
|
||||
v[12] = rotate::<8>(v[12]);
|
||||
v[13] = rotate::<8>(v[13]);
|
||||
v[14] = rotate::<8>(v[14]);
|
||||
v[15] = rotate::<8>(v[15]);
|
||||
v[8] += v[12];
|
||||
v[9] += v[13];
|
||||
v[10] += v[14];
|
||||
v[11] += v[15];
|
||||
v[4] ^= v[8];
|
||||
v[5] ^= v[9];
|
||||
v[6] ^= v[10];
|
||||
v[7] ^= v[11];
|
||||
v[4] = rotate::<7>(v[4]);
|
||||
v[5] = rotate::<7>(v[5]);
|
||||
v[6] = rotate::<7>(v[6]);
|
||||
v[7] = rotate::<7>(v[7]);
|
||||
|
||||
v[0] += m[SIGMA[r][8] as usize];
|
||||
v[1] += m[SIGMA[r][10] as usize];
|
||||
v[2] += m[SIGMA[r][12] as usize];
|
||||
v[3] += m[SIGMA[r][14] as usize];
|
||||
v[0] += v[5];
|
||||
v[1] += v[6];
|
||||
v[2] += v[7];
|
||||
v[3] += v[4];
|
||||
v[15] ^= v[0];
|
||||
v[12] ^= v[1];
|
||||
v[13] ^= v[2];
|
||||
v[14] ^= v[3];
|
||||
v[15] = rotate::<16>(v[15]);
|
||||
v[12] = rotate::<16>(v[12]);
|
||||
v[13] = rotate::<16>(v[13]);
|
||||
v[14] = rotate::<16>(v[14]);
|
||||
v[10] += v[15];
|
||||
v[11] += v[12];
|
||||
v[8] += v[13];
|
||||
v[9] += v[14];
|
||||
v[5] ^= v[10];
|
||||
v[6] ^= v[11];
|
||||
v[7] ^= v[8];
|
||||
v[4] ^= v[9];
|
||||
v[5] = rotate::<12>(v[5]);
|
||||
v[6] = rotate::<12>(v[6]);
|
||||
v[7] = rotate::<12>(v[7]);
|
||||
v[4] = rotate::<12>(v[4]);
|
||||
v[0] += m[SIGMA[r][9] as usize];
|
||||
v[1] += m[SIGMA[r][11] as usize];
|
||||
v[2] += m[SIGMA[r][13] as usize];
|
||||
v[3] += m[SIGMA[r][15] as usize];
|
||||
v[0] += v[5];
|
||||
v[1] += v[6];
|
||||
v[2] += v[7];
|
||||
v[3] += v[4];
|
||||
v[15] ^= v[0];
|
||||
v[12] ^= v[1];
|
||||
v[13] ^= v[2];
|
||||
v[14] ^= v[3];
|
||||
v[15] = rotate::<8>(v[15]);
|
||||
v[12] = rotate::<8>(v[12]);
|
||||
v[13] = rotate::<8>(v[13]);
|
||||
v[14] = rotate::<8>(v[14]);
|
||||
v[10] += v[15];
|
||||
v[11] += v[12];
|
||||
v[8] += v[13];
|
||||
v[9] += v[14];
|
||||
v[5] ^= v[10];
|
||||
v[6] ^= v[11];
|
||||
v[7] ^= v[8];
|
||||
v[4] ^= v[9];
|
||||
v[5] = rotate::<7>(v[5]);
|
||||
v[6] = rotate::<7>(v[6]);
|
||||
v[7] = rotate::<7>(v[7]);
|
||||
v[4] = rotate::<7>(v[4]);
|
||||
}
|
||||
|
||||
/// Transposes input chunks (16 chunks of 16 `u32`s each), to get 16 `u32x16`, each
|
||||
/// representing 16 packed instances of a message word.
|
||||
fn transpose_msgs(mut data: [u32x16; 16]) -> [u32x16; 16] {
|
||||
// Index abcd:xyzw, refers to a specific word in data as follows:
|
||||
// abcd - chunk index (in base 2)
|
||||
// xyzw - word offset (in base 2)
|
||||
// Transpose by applying 4 times the index permutation:
|
||||
// abcd:xyzw => wabc:dxyz
|
||||
// In other words, rotate the index to the right by 1.
|
||||
for _ in 0..4 {
|
||||
let (d0, d8) = data[0].deinterleave(data[1]);
|
||||
let (d1, d9) = data[2].deinterleave(data[3]);
|
||||
let (d2, d10) = data[4].deinterleave(data[5]);
|
||||
let (d3, d11) = data[6].deinterleave(data[7]);
|
||||
let (d4, d12) = data[8].deinterleave(data[9]);
|
||||
let (d5, d13) = data[10].deinterleave(data[11]);
|
||||
let (d6, d14) = data[12].deinterleave(data[13]);
|
||||
let (d7, d15) = data[14].deinterleave(data[15]);
|
||||
data = [
|
||||
d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15,
|
||||
];
|
||||
}
|
||||
|
||||
data
|
||||
}
|
||||
|
||||
fn untranspose_states(mut states: [u32x16; 8]) -> [u32x16; 8] {
|
||||
// Index abc:xyzw, refers to a specific word in data as follows:
|
||||
// abc - chunk index (in base 2)
|
||||
// xyzw - word offset (in base 2)
|
||||
// Transpose by applying 3 times the index permutation:
|
||||
// abc:xyzw => bcx:yzwa
|
||||
// In other words, rotate the index to the left by 1.
|
||||
for _ in 0..3 {
|
||||
let (d0, d1) = states[0].interleave(states[4]);
|
||||
let (d2, d3) = states[1].interleave(states[5]);
|
||||
let (d4, d5) = states[2].interleave(states[6]);
|
||||
let (d6, d7) = states[3].interleave(states[7]);
|
||||
states = [d0, d1, d2, d3, d4, d5, d6, d7];
|
||||
}
|
||||
states
|
||||
}
|
||||
|
||||
/// Compresses 16 blake2s instances.
|
||||
pub fn compress16(
|
||||
h_vecs: [u32x16; 8],
|
||||
msg_vecs: [u32x16; 16],
|
||||
count_low: u32x16,
|
||||
count_high: u32x16,
|
||||
lastblock: u32x16,
|
||||
lastnode: u32x16,
|
||||
) -> [u32x16; 8] {
|
||||
let mut v = [
|
||||
h_vecs[0],
|
||||
h_vecs[1],
|
||||
h_vecs[2],
|
||||
h_vecs[3],
|
||||
h_vecs[4],
|
||||
h_vecs[5],
|
||||
h_vecs[6],
|
||||
h_vecs[7],
|
||||
u32x16::splat(IV[0]),
|
||||
u32x16::splat(IV[1]),
|
||||
u32x16::splat(IV[2]),
|
||||
u32x16::splat(IV[3]),
|
||||
u32x16::splat(IV[4]) ^ count_low,
|
||||
u32x16::splat(IV[5]) ^ count_high,
|
||||
u32x16::splat(IV[6]) ^ lastblock,
|
||||
u32x16::splat(IV[7]) ^ lastnode,
|
||||
];
|
||||
|
||||
round(&mut v, msg_vecs, 0);
|
||||
round(&mut v, msg_vecs, 1);
|
||||
round(&mut v, msg_vecs, 2);
|
||||
round(&mut v, msg_vecs, 3);
|
||||
round(&mut v, msg_vecs, 4);
|
||||
round(&mut v, msg_vecs, 5);
|
||||
round(&mut v, msg_vecs, 6);
|
||||
round(&mut v, msg_vecs, 7);
|
||||
round(&mut v, msg_vecs, 8);
|
||||
round(&mut v, msg_vecs, 9);
|
||||
|
||||
[
|
||||
h_vecs[0] ^ v[0] ^ v[8],
|
||||
h_vecs[1] ^ v[1] ^ v[9],
|
||||
h_vecs[2] ^ v[2] ^ v[10],
|
||||
h_vecs[3] ^ v[3] ^ v[11],
|
||||
h_vecs[4] ^ v[4] ^ v[12],
|
||||
h_vecs[5] ^ v[5] ^ v[13],
|
||||
h_vecs[6] ^ v[6] ^ v[14],
|
||||
h_vecs[7] ^ v[7] ^ v[15],
|
||||
]
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::array;
|
||||
use std::mem::transmute;
|
||||
use std::simd::u32x16;
|
||||
|
||||
use aligned::{Aligned, A64};
|
||||
|
||||
use super::{compress16, transpose_msgs, untranspose_states};
|
||||
use crate::core::vcs::blake2s_ref::compress;
|
||||
|
||||
#[test]
|
||||
fn compress16_works() {
|
||||
let states: Aligned<A64, [[u32; 8]; 16]> =
|
||||
Aligned(array::from_fn(|i| array::from_fn(|j| (i + j) as u32)));
|
||||
let msgs: Aligned<A64, [[u32; 16]; 16]> =
|
||||
Aligned(array::from_fn(|i| array::from_fn(|j| (i + j + 20) as u32)));
|
||||
let count_low = 1;
|
||||
let count_high = 2;
|
||||
let lastblock = 3;
|
||||
let lastnode = 4;
|
||||
let res_unvectorized = array::from_fn(|i| {
|
||||
compress(
|
||||
states[i], msgs[i], count_low, count_high, lastblock, lastnode,
|
||||
)
|
||||
});
|
||||
|
||||
let res_vectorized: [[u32; 8]; 16] = unsafe {
|
||||
transmute(untranspose_states(compress16(
|
||||
transpose_states(transmute(states)),
|
||||
transpose_msgs(transmute(msgs)),
|
||||
u32x16::splat(count_low),
|
||||
u32x16::splat(count_high),
|
||||
u32x16::splat(lastblock),
|
||||
u32x16::splat(lastnode),
|
||||
)))
|
||||
};
|
||||
|
||||
assert_eq!(res_vectorized, res_unvectorized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn untranspose_states_is_transpose_states_inverse() {
|
||||
let states = array::from_fn(|i| u32x16::from(array::from_fn(|j| (i + j) as u32)));
|
||||
let transposed_states = transpose_states(states);
|
||||
|
||||
let untrasponsed_transposed_states = untranspose_states(transposed_states);
|
||||
|
||||
assert_eq!(untrasponsed_transposed_states, states)
|
||||
}
|
||||
|
||||
/// Transposes states, from 8 packed words, to get 16 results, each of size 32B.
|
||||
fn transpose_states(mut states: [u32x16; 8]) -> [u32x16; 8] {
|
||||
// Index abc:xyzw, refers to a specific word in data as follows:
|
||||
// abc - chunk index (in base 2)
|
||||
// xyzw - word offset (in base 2)
|
||||
// Transpose by applying 3 times the index permutation:
|
||||
// abc:xyzw => wab:cxyz
|
||||
// In other words, rotate the index to the right by 1.
|
||||
for _ in 0..3 {
|
||||
let (s0, s4) = states[0].deinterleave(states[1]);
|
||||
let (s1, s5) = states[2].deinterleave(states[3]);
|
||||
let (s2, s6) = states[4].deinterleave(states[5]);
|
||||
let (s3, s7) = states[6].deinterleave(states[7]);
|
||||
states = [s0, s1, s2, s3, s4, s5, s6, s7];
|
||||
}
|
||||
|
||||
states
|
||||
}
|
||||
}
|
||||
436
Stwo_wrapper/crates/prover/src/core/backend/simd/circle.rs
Normal file
@ -0,0 +1,436 @@
|
||||
use std::iter::zip;
|
||||
use std::mem::transmute;
|
||||
|
||||
use bytemuck::{cast_slice, Zeroable};
|
||||
use num_traits::One;
|
||||
|
||||
use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE};
|
||||
use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
|
||||
use super::qm31::PackedSecureField;
|
||||
use super::SimdBackend;
|
||||
use crate::core::backend::simd::column::BaseColumn;
|
||||
use crate::core::backend::{Col, CpuBackend};
|
||||
use crate::core::circle::{CirclePoint, Coset};
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::{Field, FieldExpOps};
|
||||
use crate::core::poly::circle::{
|
||||
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
|
||||
};
|
||||
use crate::core::poly::twiddles::TwiddleTree;
|
||||
use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold};
|
||||
use crate::core::poly::BitReversedOrder;
|
||||
|
||||
impl SimdBackend {
|
||||
// TODO(Ohad): optimize.
|
||||
fn twiddle_at<F: Field>(mappings: &[F], mut index: usize) -> F {
|
||||
debug_assert!(
|
||||
(1 << mappings.len()) as usize >= index,
|
||||
"Index out of bounds. mappings log len = {}, index = {index}",
|
||||
mappings.len().ilog2()
|
||||
);
|
||||
|
||||
let mut product = F::one();
|
||||
for &num in mappings.iter() {
|
||||
if index & 1 == 1 {
|
||||
product *= num;
|
||||
}
|
||||
index >>= 1;
|
||||
if index == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
product
|
||||
}
|
||||
|
||||
// TODO(Ohad): consider moving this to a more general place.
|
||||
// Note: CACHED_FFT_LOG_SIZE is specific to the backend.
|
||||
fn generate_evaluation_mappings<F: Field>(point: CirclePoint<F>, log_size: u32) -> Vec<F> {
|
||||
// Mappings are the factors used to compute the evaluation twiddle.
|
||||
// Every twiddle (i) is of the form (m[0])^b_0 * (m[1])^b_1 * ... * (m[log_size -
|
||||
// 1])^b_log_size.
|
||||
// Where (m)_j are the mappings, and b_j is the j'th bit of i.
|
||||
let mut mappings = vec![point.y, point.x];
|
||||
let mut x = point.x;
|
||||
for _ in 2..log_size {
|
||||
x = CirclePoint::double_x(x);
|
||||
mappings.push(x);
|
||||
}
|
||||
|
||||
// The caller function expects the mapping in natural order. i.e. (y,x,h(x),h(h(x)),...).
|
||||
// If the polynomial is large, the fft does a transpose in the middle in a granularity of 16
|
||||
// (avx512). The coefficients would then be in transposed order of 16-sized chunks.
|
||||
// i.e. (a_(n-15), a_(n-14), ..., a_(n-1), a_(n-31), ..., a_(n-16), a_(n-32), ...).
|
||||
// To compute the twiddles in the correct order, we need to transpose the corresponding
|
||||
// 'transposed bits' in the mappings. The result order of the mappings would then be
|
||||
// (y, x, h(x), h^2(x), h^(log_n-1)(x), h^(log_n-2)(x) ...). To avoid code
|
||||
// complexity for now, we just reverse the mappings, transpose, then reverse back.
|
||||
// TODO(Ohad): optimize. consider changing the caller to expect the mappings in
|
||||
// reversed-transposed order.
|
||||
if log_size > CACHED_FFT_LOG_SIZE {
|
||||
mappings.reverse();
|
||||
let n = mappings.len();
|
||||
let n0 = (n - LOG_N_LANES as usize) / 2;
|
||||
let n1 = (n - LOG_N_LANES as usize + 1) / 2;
|
||||
let (ab, c) = mappings.split_at_mut(n1);
|
||||
let (a, _b) = ab.split_at_mut(n0);
|
||||
// Swap content of a,c.
|
||||
a.swap_with_slice(&mut c[0..n0]);
|
||||
mappings.reverse();
|
||||
}
|
||||
|
||||
mappings
|
||||
}
|
||||
|
||||
// Generates twiddle steps for efficiently computing the twiddles.
|
||||
// steps[i] = t_i/(t_0*t_1*...*t_i-1).
|
||||
fn twiddle_steps<F: Field>(mappings: &[F]) -> Vec<F>
|
||||
where
|
||||
F: FieldExpOps,
|
||||
{
|
||||
let mut denominators: Vec<F> = vec![mappings[0]];
|
||||
|
||||
for i in 1..mappings.len() {
|
||||
denominators.push(denominators[i - 1] * mappings[i]);
|
||||
}
|
||||
|
||||
let mut denom_inverses = vec![F::zero(); denominators.len()];
|
||||
F::batch_inverse(&denominators, &mut denom_inverses);
|
||||
|
||||
let mut steps = vec![mappings[0]];
|
||||
|
||||
mappings
|
||||
.iter()
|
||||
.skip(1)
|
||||
.zip(denom_inverses.iter())
|
||||
.for_each(|(&m, &d)| {
|
||||
steps.push(m * d);
|
||||
});
|
||||
steps.push(F::one());
|
||||
steps
|
||||
}
|
||||
|
||||
// Advances the twiddle by multiplying it by the next step. e.g:
|
||||
// If idx(t) = 0b100..1010 , then f(t) = t * step[0]
|
||||
// If idx(t) = 0b100..0111 , then f(t) = t * step[3]
|
||||
fn advance_twiddle<F: Field>(twiddle: F, steps: &[F], curr_idx: usize) -> F {
|
||||
twiddle * steps[curr_idx.trailing_ones() as usize]
|
||||
}
|
||||
}
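
// A small worked sketch (illustrative only, i64 stand-ins, not part of the diff) of what
// twiddle_at above computes: the twiddle at `index` is the product of the mappings selected
// by the set bits of `index`. For example, twiddle_at_sketch(&[2, 3, 5], 0b101) == 2 * 5 == 10,
// and the empty selection (index 0) gives the multiplicative identity.
fn twiddle_at_sketch(mappings: &[i64], index: usize) -> i64 {
    mappings
        .iter()
        .enumerate()
        .filter(|(bit, _)| (index >> bit) & 1 == 1)
        .map(|(_, &m)| m)
        .product()
}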
|
||||
|
||||
// TODO(spapini): Everything is returned in redundant representation, where values can also be P.
|
||||
// Decide if and when it's ok and what to do if it's not.
|
||||
impl PolyOps for SimdBackend {
|
||||
// The twiddles type is i32, and not BaseField. This is because the fast AVX mul implementation
|
||||
// requires one of the numbers to be shifted left by 1 bit. This is not a reduced
|
||||
// representation of the field.
|
||||
type Twiddles = Vec<u32>;
|
||||
|
||||
fn new_canonical_ordered(
|
||||
coset: CanonicCoset,
|
||||
values: Col<Self, BaseField>,
|
||||
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
|
||||
// TODO(spapini): Optimize.
|
||||
let eval = CpuBackend::new_canonical_ordered(coset, values.into_cpu_vec());
|
||||
CircleEvaluation::new(
|
||||
eval.domain,
|
||||
Col::<SimdBackend, BaseField>::from_iter(eval.values),
|
||||
)
|
||||
}
|
||||
|
||||
fn interpolate(
|
||||
eval: CircleEvaluation<Self, BaseField, BitReversedOrder>,
|
||||
twiddles: &TwiddleTree<Self>,
|
||||
) -> CirclePoly<Self> {
|
||||
let mut values = eval.values;
|
||||
let log_size = values.length.ilog2();
|
||||
|
||||
let twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles);
|
||||
|
||||
// Safe because [PackedBaseField] is aligned on 64 bytes.
|
||||
unsafe {
|
||||
ifft::ifft(
|
||||
transmute(values.data.as_mut_ptr()),
|
||||
&twiddles,
|
||||
log_size as usize,
|
||||
);
|
||||
}
|
||||
|
||||
// TODO(spapini): Fuse this multiplication / rotation.
|
||||
let inv = PackedBaseField::broadcast(BaseField::from(eval.domain.size()).inverse());
|
||||
values.data.iter_mut().for_each(|x| *x *= inv);
|
||||
|
||||
CirclePoly::new(values)
|
||||
}
|
||||
|
||||
fn eval_at_point(poly: &CirclePoly<Self>, point: CirclePoint<SecureField>) -> SecureField {
|
||||
// If the polynomial is small, fallback to evaluate directly.
|
||||
// TODO(Ohad): it's possible to avoid falling back. Consider fixing.
|
||||
if poly.log_size() <= 8 {
|
||||
return slow_eval_at_point(poly, point);
|
||||
}
|
||||
|
||||
let mappings = Self::generate_evaluation_mappings(point, poly.log_size());
|
||||
|
||||
// 8 lowest mappings produce the first 2^8 twiddles. Separate to optimize each calculation.
|
||||
let (map_low, map_high) = mappings.split_at(4);
|
||||
let twiddle_lows =
|
||||
PackedSecureField::from_array(std::array::from_fn(|i| Self::twiddle_at(map_low, i)));
|
||||
let (map_mid, map_high) = map_high.split_at(4);
|
||||
let twiddle_mids =
|
||||
PackedSecureField::from_array(std::array::from_fn(|i| Self::twiddle_at(map_mid, i)));
|
||||
|
||||
// Compute the high twiddle steps.
|
||||
let twiddle_steps = Self::twiddle_steps(map_high);
|
||||
|
||||
// Every twiddle is a product of mappings that correspond to '1's in the bit representation
|
||||
// of the current index. For every 2^n aligned chunk of 2^n elements, the twiddle
|
||||
// array is the same, denoted twiddle_low. Use this to compute sums of (coeff *
|
||||
// twiddle_high) mod 2^n, then multiply by twiddle_low, and sum to get the final result.
|
||||
let mut sum = PackedSecureField::zeroed();
|
||||
let mut twiddle_high = SecureField::one();
|
||||
for (i, coeff_chunk) in poly.coeffs.data.array_chunks::<N_LANES>().enumerate() {
|
||||
// For every chunk of 2 ^ 4 * 2 ^ 4 = 2 ^ 8 elements, the twiddle high is the same.
|
||||
// Multiply it by every mid twiddle factor to get the factors for the current chunk.
|
||||
let high_twiddle_factors =
|
||||
(PackedSecureField::broadcast(twiddle_high) * twiddle_mids).to_array();
|
||||
|
||||
// Sum the coefficients multiplied by each corresponding twiddle. Result is effectively
|
||||
// an array[16] where the value at index 'i' is the sum of all coefficients at indices
|
||||
// that are i mod 16.
|
||||
for (&packed_coeffs, mid_twiddle) in zip(coeff_chunk, high_twiddle_factors) {
|
||||
sum += PackedSecureField::broadcast(mid_twiddle) * packed_coeffs;
|
||||
}
|
||||
|
||||
// Advance twiddle high.
|
||||
twiddle_high = Self::advance_twiddle(twiddle_high, &twiddle_steps, i);
|
||||
}
|
||||
|
||||
(sum * twiddle_lows).pointwise_sum()
|
||||
}
|
||||
|
||||
fn extend(poly: &CirclePoly<Self>, log_size: u32) -> CirclePoly<Self> {
|
||||
// TODO(spapini): Optimize or get rid of extend.
|
||||
poly.evaluate(CanonicCoset::new(log_size).circle_domain())
|
||||
.interpolate()
|
||||
}
|
||||
|
||||
fn evaluate(
|
||||
poly: &CirclePoly<Self>,
|
||||
domain: CircleDomain,
|
||||
twiddles: &TwiddleTree<Self>,
|
||||
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
|
||||
// TODO(spapini): Precompute twiddles.
|
||||
// TODO(spapini): Handle small cases.
|
||||
let log_size = domain.log_size();
|
||||
let fft_log_size = poly.log_size();
|
||||
assert!(
|
||||
log_size >= fft_log_size,
|
||||
"Can only evaluate on larger domains"
|
||||
);
|
||||
|
||||
let twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles);
|
||||
|
||||
// Evaluate on a big domain by evaluating on several subdomains.
|
||||
let log_subdomains = log_size - fft_log_size;
|
||||
|
||||
// Allocate the destination buffer without initializing.
|
||||
let mut values = Vec::with_capacity(domain.size() >> LOG_N_LANES);
|
||||
#[allow(clippy::uninit_vec)]
|
||||
unsafe {
|
||||
values.set_len(domain.size() >> LOG_N_LANES)
|
||||
};
|
||||
|
||||
for i in 0..(1 << log_subdomains) {
|
||||
// The subdomain twiddles are a slice of the large domain twiddles.
|
||||
let subdomain_twiddles = (0..(fft_log_size - 1))
|
||||
.map(|layer_i| {
|
||||
&twiddles[layer_i as usize]
|
||||
[i << (fft_log_size - 2 - layer_i)..(i + 1) << (fft_log_size - 2 - layer_i)]
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// FFT from the coefficients buffer to the values chunk.
|
||||
unsafe {
|
||||
rfft::fft(
|
||||
transmute(poly.coeffs.data.as_ptr()),
|
||||
transmute(
|
||||
values[i << (fft_log_size - LOG_N_LANES)
|
||||
..(i + 1) << (fft_log_size - LOG_N_LANES)]
|
||||
.as_mut_ptr(),
|
||||
),
|
||||
&subdomain_twiddles,
|
||||
fft_log_size as usize,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
CircleEvaluation::new(
|
||||
domain,
|
||||
BaseColumn {
|
||||
data: values,
|
||||
length: domain.size(),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {
|
||||
let mut twiddles = Vec::with_capacity(coset.size());
|
||||
let mut itwiddles = Vec::with_capacity(coset.size());
|
||||
|
||||
// TODO(spapini): Optimize.
|
||||
for layer in &rfft::get_twiddle_dbls(coset) {
|
||||
twiddles.extend(layer);
|
||||
}
|
||||
// Pad by any value, to make the size a power of 2.
|
||||
twiddles.push(1);
|
||||
assert_eq!(twiddles.len(), coset.size());
|
||||
for layer in &ifft::get_itwiddle_dbls(coset) {
|
||||
itwiddles.extend(layer);
|
||||
}
|
||||
// Pad by any value, to make the size a power of 2.
|
||||
itwiddles.push(1);
|
||||
assert_eq!(itwiddles.len(), coset.size());
|
||||
|
||||
TwiddleTree {
|
||||
root_coset: coset,
|
||||
twiddles,
|
||||
itwiddles,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn slow_eval_at_point(
|
||||
poly: &CirclePoly<SimdBackend>,
|
||||
point: CirclePoint<SecureField>,
|
||||
) -> SecureField {
|
||||
let mut mappings = vec![point.y, point.x];
|
||||
let mut x = point.x;
|
||||
for _ in 2..poly.log_size() {
|
||||
x = CirclePoint::double_x(x);
|
||||
mappings.push(x);
|
||||
}
|
||||
mappings.reverse();
|
||||
|
||||
// If the polynomial is large, the fft does a transpose in the middle.
|
||||
if poly.log_size() > CACHED_FFT_LOG_SIZE {
|
||||
let n = mappings.len();
|
||||
let n0 = (n - LOG_N_LANES as usize) / 2;
|
||||
let n1 = (n - LOG_N_LANES as usize + 1) / 2;
|
||||
let (ab, c) = mappings.split_at_mut(n1);
|
||||
let (a, _b) = ab.split_at_mut(n0);
|
||||
// Swap content of a,c.
|
||||
a.swap_with_slice(&mut c[0..n0]);
|
||||
}
|
||||
fold(cast_slice::<_, BaseField>(&poly.coeffs.data), &mappings)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use crate::core::backend::simd::circle::slow_eval_at_point;
|
||||
use crate::core::backend::simd::fft::{CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
|
||||
use crate::core::backend::simd::SimdBackend;
|
||||
use crate::core::backend::Column;
|
||||
use crate::core::circle::CirclePoint;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, PolyOps};
|
||||
use crate::core::poly::{BitReversedOrder, NaturalOrder};
|
||||
|
||||
#[test]
|
||||
fn test_interpolate_and_eval() {
|
||||
for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 4 {
|
||||
let domain = CanonicCoset::new(log_size).circle_domain();
|
||||
let evaluation = CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(
|
||||
domain,
|
||||
(0..1 << log_size).map(BaseField::from).collect(),
|
||||
);
|
||||
|
||||
let poly = evaluation.clone().interpolate();
|
||||
let evaluation2 = poly.evaluate(domain);
|
||||
|
||||
assert_eq!(evaluation.values.to_cpu(), evaluation2.values.to_cpu());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_eval_extension() {
|
||||
for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 2 {
|
||||
let domain = CanonicCoset::new(log_size).circle_domain();
|
||||
let domain_ext = CanonicCoset::new(log_size + 2).circle_domain();
|
||||
let evaluation = CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(
|
||||
domain,
|
||||
(0..1 << log_size).map(BaseField::from).collect(),
|
||||
);
|
||||
let poly = evaluation.clone().interpolate();
|
||||
|
||||
let evaluation2 = poly.evaluate(domain_ext);
|
||||
|
||||
assert_eq!(
|
||||
poly.extend(log_size + 2).coeffs.to_cpu(),
|
||||
evaluation2.interpolate().coeffs.to_cpu()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_eval_at_point() {
|
||||
for log_size in MIN_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 4 {
|
||||
let domain = CanonicCoset::new(log_size).circle_domain();
|
||||
let evaluation = CircleEvaluation::<SimdBackend, BaseField, NaturalOrder>::new(
|
||||
domain,
|
||||
(0..1 << log_size).map(BaseField::from).collect(),
|
||||
);
|
||||
let poly = evaluation.bit_reverse().interpolate();
|
||||
for i in [0, 1, 3, 1 << (log_size - 1), 1 << (log_size - 2)] {
|
||||
let p = domain.at(i);
|
||||
|
||||
let eval = poly.eval_at_point(p.into_ef());
|
||||
|
||||
assert_eq!(
|
||||
eval,
|
||||
BaseField::from(i).into(),
|
||||
"log_size={log_size}, i={i}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circle_poly_extend() {
|
||||
for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 2 {
|
||||
let poly =
|
||||
CirclePoly::<SimdBackend>::new((0..1 << log_size).map(BaseField::from).collect());
|
||||
let eval0 = poly.evaluate(CanonicCoset::new(log_size + 2).circle_domain());
|
||||
|
||||
let eval1 = poly
|
||||
.extend(log_size + 2)
|
||||
.evaluate(CanonicCoset::new(log_size + 2).circle_domain());
|
||||
|
||||
assert_eq!(eval0.values.to_cpu(), eval1.values.to_cpu());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_eval_securefield() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 2 {
|
||||
let domain = CanonicCoset::new(log_size).circle_domain();
|
||||
let evaluation = CircleEvaluation::<SimdBackend, BaseField, NaturalOrder>::new(
|
||||
domain,
|
||||
(0..1 << log_size).map(BaseField::from).collect(),
|
||||
);
|
||||
let poly = evaluation.bit_reverse().interpolate();
|
||||
let x = rng.gen();
|
||||
let y = rng.gen();
|
||||
let p = CirclePoint { x, y };
|
||||
|
||||
let eval = PolyOps::eval_at_point(&poly, p);
|
||||
|
||||
assert_eq!(eval, slow_eval_at_point(&poly, p), "log_size = {log_size}");
|
||||
}
|
||||
}
|
||||
}
|
||||
230
Stwo_wrapper/crates/prover/src/core/backend/simd/cm31.rs
Normal file
@ -0,0 +1,230 @@
|
||||
use std::array;
|
||||
use std::ops::{Add, Mul, MulAssign, Neg, Sub};
|
||||
|
||||
use bytemuck::{Pod, Zeroable};
|
||||
use num_traits::{One, Zero};
|
||||
|
||||
use super::m31::{PackedM31, N_LANES};
|
||||
use crate::core::fields::cm31::CM31;
|
||||
use crate::core::fields::FieldExpOps;
|
||||
|
||||
/// SIMD implementation of [`CM31`].
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct PackedCM31(pub [PackedM31; 2]);
|
||||
|
||||
impl PackedCM31 {
|
||||
/// Constructs a new instance with all vector elements set to `value`.
|
||||
pub fn broadcast(value: CM31) -> Self {
|
||||
Self([PackedM31::broadcast(value.0), PackedM31::broadcast(value.1)])
|
||||
}
|
||||
|
||||
/// Returns all `a` values such that each vector element is represented as `a + bi`.
|
||||
pub fn a(&self) -> PackedM31 {
|
||||
self.0[0]
|
||||
}
|
||||
|
||||
/// Returns all `b` values such that each vector element is represented as `a + bi`.
|
||||
pub fn b(&self) -> PackedM31 {
|
||||
self.0[1]
|
||||
}
|
||||
|
||||
pub fn to_array(&self) -> [CM31; N_LANES] {
|
||||
let a = self.a().to_array();
|
||||
let b = self.b().to_array();
|
||||
array::from_fn(|i| CM31(a[i], b[i]))
|
||||
}
|
||||
|
||||
pub fn from_array(values: [CM31; N_LANES]) -> Self {
|
||||
Self([
|
||||
PackedM31::from_array(values.map(|v| v.0)),
|
||||
PackedM31::from_array(values.map(|v| v.1)),
|
||||
])
|
||||
}
|
||||
|
||||
/// Interleaves two vectors.
|
||||
pub fn interleave(self, other: Self) -> (Self, Self) {
|
||||
let Self([a_evens, b_evens]) = self;
|
||||
let Self([a_odds, b_odds]) = other;
|
||||
let (a_lhs, a_rhs) = a_evens.interleave(a_odds);
|
||||
let (b_lhs, b_rhs) = b_evens.interleave(b_odds);
|
||||
(Self([a_lhs, b_lhs]), Self([a_rhs, b_rhs]))
|
||||
}
|
||||
|
||||
/// Deinterleaves two vectors.
|
||||
pub fn deinterleave(self, other: Self) -> (Self, Self) {
|
||||
let Self([a_self, b_self]) = self;
|
||||
let Self([a_other, b_other]) = other;
|
||||
let (a_evens, a_odds) = a_self.deinterleave(a_other);
|
||||
let (b_evens, b_odds) = b_self.deinterleave(b_other);
|
||||
(Self([a_evens, b_evens]), Self([a_odds, b_odds]))
|
||||
}
|
||||
|
||||
/// Doubles each element in the vector.
|
||||
pub fn double(self) -> Self {
|
||||
let Self([a, b]) = self;
|
||||
Self([a.double(), b.double()])
|
||||
}
|
||||
}
|
||||
|
||||
impl Add for PackedCM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
Self([self.a() + rhs.a(), self.b() + rhs.b()])
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub for PackedCM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
Self([self.a() - rhs.a(), self.b() - rhs.b()])
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul for PackedCM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
// Compute using Karatsuba.
|
||||
let ac = self.a() * rhs.a();
|
||||
let bd = self.b() * rhs.b();
|
||||
// Computes (a + b) * (c + d).
|
||||
let ab_t_cd = (self.a() + self.b()) * (rhs.a() + rhs.b());
|
||||
// (ac - bd) + (ad + bc)i.
|
||||
Self([ac - bd, ab_t_cd - ac - bd])
|
||||
}
|
||||
}
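// Illustrative sketch (not part of the diff): the same Karatsuba identity on scalar
// [`CM31`] values, where `CM31(a, b)` stands for a + bi. Hypothetical helper, kept
// test-only.
#[cfg(test)]
#[allow(dead_code)]
fn cm31_mul_reference(lhs: CM31, rhs: CM31) -> CM31 {
    let ac = lhs.0 * rhs.0;
    let bd = lhs.1 * rhs.1;
    // (a + b)(c + d) - ac - bd = ad + bc, saving one multiplication.
    let ab_t_cd = (lhs.0 + lhs.1) * (rhs.0 + rhs.1);
    CM31(ac - bd, ab_t_cd - ac - bd)
}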
|
||||
|
||||
impl Zero for PackedCM31 {
|
||||
fn zero() -> Self {
|
||||
Self([PackedM31::zero(), PackedM31::zero()])
|
||||
}
|
||||
|
||||
fn is_zero(&self) -> bool {
|
||||
self.a().is_zero() && self.b().is_zero()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Pod for PackedCM31 {}
|
||||
|
||||
unsafe impl Zeroable for PackedCM31 {
|
||||
fn zeroed() -> Self {
|
||||
unsafe { core::mem::zeroed() }
|
||||
}
|
||||
}
|
||||
|
||||
impl One for PackedCM31 {
|
||||
fn one() -> Self {
|
||||
Self([PackedM31::one(), PackedM31::zero()])
|
||||
}
|
||||
}
|
||||
|
||||
impl MulAssign for PackedCM31 {
|
||||
fn mul_assign(&mut self, rhs: Self) {
|
||||
*self = *self * rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl FieldExpOps for PackedCM31 {
|
||||
fn inverse(&self) -> Self {
|
||||
assert!(!self.is_zero(), "0 has no inverse");
|
||||
// 1 / (a + bi) = (a - bi) / (a^2 + b^2).
|
||||
Self([self.a(), -self.b()]) * (self.a().square() + self.b().square()).inverse()
|
||||
}
|
||||
}
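// Illustrative sketch (not part of the diff): the conjugate identity behind `inverse`,
// 1 / (a + bi) = (a - bi) / (a^2 + b^2), spelled out on a scalar value. Hypothetical
// helper, kept test-only.
#[cfg(test)]
#[allow(dead_code)]
fn cm31_inverse_reference(v: CM31) -> CM31 {
    let denom_inv = (v.0.square() + v.1.square()).inverse();
    CM31(v.0 * denom_inv, -v.1 * denom_inv)
}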
|
||||
|
||||
impl Add<PackedM31> for PackedCM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: PackedM31) -> Self::Output {
|
||||
Self([self.a() + rhs, self.b()])
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<PackedM31> for PackedCM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: PackedM31) -> Self::Output {
|
||||
let Self([a, b]) = self;
|
||||
Self([a - rhs, b])
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<PackedM31> for PackedCM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, rhs: PackedM31) -> Self::Output {
|
||||
let Self([a, b]) = self;
|
||||
Self([a * rhs, b * rhs])
|
||||
}
|
||||
}
|
||||
|
||||
impl Neg for PackedCM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
let Self([a, b]) = self;
|
||||
Self([-a, -b])
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::array;
|
||||
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use crate::core::backend::simd::cm31::PackedCM31;
|
||||
|
||||
#[test]
|
||||
fn addition_works() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let lhs = rng.gen();
|
||||
let rhs = rng.gen();
|
||||
let packed_lhs = PackedCM31::from_array(lhs);
|
||||
let packed_rhs = PackedCM31::from_array(rhs);
|
||||
|
||||
let res = packed_lhs + packed_rhs;
|
||||
|
||||
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn subtraction_works() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let lhs = rng.gen();
|
||||
let rhs = rng.gen();
|
||||
let packed_lhs = PackedCM31::from_array(lhs);
|
||||
let packed_rhs = PackedCM31::from_array(rhs);
|
||||
|
||||
let res = packed_lhs - packed_rhs;
|
||||
|
||||
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiplication_works() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let lhs = rng.gen();
|
||||
let rhs = rng.gen();
|
||||
let packed_lhs = PackedCM31::from_array(lhs);
|
||||
let packed_rhs = PackedCM31::from_array(rhs);
|
||||
|
||||
let res = packed_lhs * packed_rhs;
|
||||
|
||||
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn negation_works() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let values = rng.gen();
|
||||
let packed_values = PackedCM31::from_array(values);
|
||||
|
||||
let res = -packed_values;
|
||||
|
||||
assert_eq!(res.to_array(), values.map(|v| -v));
|
||||
}
|
||||
}
|
||||
656
Stwo_wrapper/crates/prover/src/core/backend/simd/column.rs
Normal file
@ -0,0 +1,656 @@
|
||||
use std::iter::zip;
|
||||
use std::{array, mem};
|
||||
|
||||
use bytemuck::allocation::cast_vec;
|
||||
use bytemuck::{cast_slice, cast_slice_mut, Zeroable};
|
||||
use itertools::{izip, Itertools};
|
||||
use num_traits::Zero;
|
||||
|
||||
use super::cm31::PackedCM31;
|
||||
use super::m31::{PackedBaseField, N_LANES};
|
||||
use super::qm31::{PackedQM31, PackedSecureField};
|
||||
use super::very_packed_m31::{VeryPackedBaseField, VeryPackedSecureField, N_VERY_PACKED_ELEMS};
|
||||
use super::SimdBackend;
|
||||
use crate::core::backend::{Column, CpuBackend};
|
||||
use crate::core::fields::cm31::CM31;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE};
|
||||
use crate::core::fields::{FieldExpOps, FieldOps};
|
||||
|
||||
impl FieldOps<BaseField> for SimdBackend {
|
||||
fn batch_inverse(column: &BaseColumn, dst: &mut BaseColumn) {
|
||||
PackedBaseField::batch_inverse(&column.data, &mut dst.data);
|
||||
}
|
||||
}
|
||||
|
||||
impl FieldOps<SecureField> for SimdBackend {
|
||||
fn batch_inverse(column: &SecureColumn, dst: &mut SecureColumn) {
|
||||
PackedSecureField::batch_inverse(&column.data, &mut dst.data);
|
||||
}
|
||||
}
|
||||
|
||||
/// An efficient structure for storing and operating on an arbitrary number of [`BaseField`] values.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct BaseColumn {
|
||||
pub data: Vec<PackedBaseField>,
|
||||
/// The number of [`BaseField`]s in the vector.
|
||||
pub length: usize,
|
||||
}
|
||||
|
||||
impl BaseColumn {
|
||||
/// Extracts a slice containing the entire vector of [`BaseField`]s.
|
||||
pub fn as_slice(&self) -> &[BaseField] {
|
||||
&cast_slice(&self.data)[..self.length]
|
||||
}
|
||||
|
||||
/// Extracts a mutable slice containing the entire vector of [`BaseField`]s.
|
||||
pub fn as_mut_slice(&mut self) -> &mut [BaseField] {
|
||||
&mut cast_slice_mut(&mut self.data)[..self.length]
|
||||
}
|
||||
|
||||
pub fn into_cpu_vec(mut self) -> Vec<BaseField> {
|
||||
let capacity = self.data.capacity() * N_LANES;
|
||||
let length = self.length;
|
||||
let ptr = self.data.as_mut_ptr() as *mut BaseField;
|
||||
let res = unsafe { Vec::from_raw_parts(ptr, length, capacity) };
|
||||
mem::forget(self);
|
||||
res
|
||||
}
|
||||
|
||||
/// Returns a vector of `BaseColumnMutSlice`s, each mutably owning
|
||||
/// `chunk_size` `PackedBaseField`s (i.e., `chunk_size` * `N_LANES` elements).
|
||||
pub fn chunks_mut(&mut self, chunk_size: usize) -> Vec<BaseColumnMutSlice<'_>> {
|
||||
self.data
|
||||
.chunks_mut(chunk_size)
|
||||
.map(BaseColumnMutSlice)
|
||||
.collect_vec()
|
||||
}
|
||||
|
||||
pub fn into_secure_column(self) -> SecureColumn {
|
||||
let length = self.len();
|
||||
let data = self.data.into_iter().map(PackedSecureField::from).collect();
|
||||
SecureColumn { data, length }
|
||||
}
|
||||
}
|
||||
|
||||
impl Column<BaseField> for BaseColumn {
|
||||
fn zeros(length: usize) -> Self {
|
||||
let data = vec![PackedBaseField::zeroed(); length.div_ceil(N_LANES)];
|
||||
Self { data, length }
|
||||
}
|
||||
|
||||
#[allow(clippy::uninit_vec)]
|
||||
unsafe fn uninitialized(length: usize) -> Self {
|
||||
let mut data = Vec::with_capacity(length.div_ceil(N_LANES));
|
||||
data.set_len(length.div_ceil(N_LANES));
|
||||
Self { data, length }
|
||||
}
|
||||
|
||||
fn to_cpu(&self) -> Vec<BaseField> {
|
||||
self.as_slice().to_vec()
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.length
|
||||
}
|
||||
|
||||
fn at(&self, index: usize) -> BaseField {
|
||||
self.data[index / N_LANES].to_array()[index % N_LANES]
|
||||
}
|
||||
|
||||
fn set(&mut self, index: usize, value: BaseField) {
|
||||
let mut packed = self.data[index / N_LANES].to_array();
|
||||
packed[index % N_LANES] = value;
|
||||
self.data[index / N_LANES] = PackedBaseField::from_array(packed)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<BaseField> for BaseColumn {
|
||||
fn from_iter<I: IntoIterator<Item = BaseField>>(iter: I) -> Self {
|
||||
let mut chunks = iter.into_iter().array_chunks();
|
||||
let mut data = (&mut chunks).map(PackedBaseField::from_array).collect_vec();
|
||||
let mut length = data.len() * N_LANES;
|
||||
|
||||
if let Some(remainder) = chunks.into_remainder() {
|
||||
if !remainder.is_empty() {
|
||||
length += remainder.len();
|
||||
let mut last = [BaseField::zero(); N_LANES];
|
||||
last[..remainder.len()].copy_from_slice(remainder.as_slice());
|
||||
data.push(PackedBaseField::from_array(last));
|
||||
}
|
||||
}
|
||||
|
||||
Self { data, length }
|
||||
}
|
||||
}
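// Hypothetical test added for illustration (not part of the diff): collecting a number of
// elements that is not a multiple of `N_LANES` zero-pads the last packed vector, while
// `length` keeps the true element count.
#[test]
fn base_column_from_iter_pads_last_vector() {
    let col = (0..20usize).map(BaseField::from).collect::<BaseColumn>();
    assert_eq!(col.length, 20);
    assert_eq!(col.data.len(), 20usize.div_ceil(N_LANES));
    assert_eq!(col.to_cpu().len(), 20);
}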
|
||||
|
||||
// An efficient structure for storing and operating on an arbitrary number of [`CM31`] values.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CM31Column {
|
||||
pub data: Vec<PackedCM31>,
|
||||
pub length: usize,
|
||||
}
|
||||
|
||||
impl Column<CM31> for CM31Column {
|
||||
fn zeros(length: usize) -> Self {
|
||||
Self {
|
||||
data: vec![PackedCM31::zeroed(); length.div_ceil(N_LANES)],
|
||||
length,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::uninit_vec)]
|
||||
unsafe fn uninitialized(length: usize) -> Self {
|
||||
let mut data = Vec::with_capacity(length.div_ceil(N_LANES));
|
||||
data.set_len(length.div_ceil(N_LANES));
|
||||
Self { data, length }
|
||||
}
|
||||
|
||||
fn to_cpu(&self) -> Vec<CM31> {
|
||||
self.data
|
||||
.iter()
|
||||
.flat_map(|x| x.to_array())
|
||||
.take(self.length)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.length
|
||||
}
|
||||
|
||||
fn at(&self, index: usize) -> CM31 {
|
||||
self.data[index / N_LANES].to_array()[index % N_LANES]
|
||||
}
|
||||
|
||||
fn set(&mut self, index: usize, value: CM31) {
|
||||
let mut packed = self.data[index / N_LANES].to_array();
|
||||
packed[index % N_LANES] = value;
|
||||
self.data[index / N_LANES] = PackedCM31::from_array(packed)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<CM31> for CM31Column {
|
||||
fn from_iter<I: IntoIterator<Item = CM31>>(iter: I) -> Self {
|
||||
let mut chunks = iter.into_iter().array_chunks();
|
||||
let mut data = (&mut chunks).map(PackedCM31::from_array).collect_vec();
|
||||
let mut length = data.len() * N_LANES;
|
||||
|
||||
if let Some(remainder) = chunks.into_remainder() {
|
||||
if !remainder.is_empty() {
|
||||
length += remainder.len();
|
||||
let mut last = [CM31::zero(); N_LANES];
|
||||
last[..remainder.len()].copy_from_slice(remainder.as_slice());
|
||||
data.push(PackedCM31::from_array(last));
|
||||
}
|
||||
}
|
||||
|
||||
Self { data, length }
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<PackedCM31> for CM31Column {
|
||||
fn from_iter<I: IntoIterator<Item = PackedCM31>>(iter: I) -> Self {
|
||||
let data = (&mut iter.into_iter()).collect_vec();
|
||||
let length = data.len() * N_LANES;
|
||||
|
||||
Self { data, length }
|
||||
}
|
||||
}
|
||||
|
||||
/// A mutable slice of a BaseColumn.
|
||||
pub struct BaseColumnMutSlice<'a>(pub &'a mut [PackedBaseField]);
|
||||
|
||||
impl<'a> BaseColumnMutSlice<'a> {
|
||||
pub fn at(&self, index: usize) -> BaseField {
|
||||
self.0[index / N_LANES].to_array()[index % N_LANES]
|
||||
}
|
||||
|
||||
pub fn set(&mut self, index: usize, value: BaseField) {
|
||||
let mut packed = self.0[index / N_LANES].to_array();
|
||||
packed[index % N_LANES] = value;
|
||||
self.0[index / N_LANES] = PackedBaseField::from_array(packed)
|
||||
}
|
||||
}
|
||||
|
||||
/// An efficient structure for storing and operating on an arbitrary number of [`SecureField`]
|
||||
/// values.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SecureColumn {
|
||||
pub data: Vec<PackedSecureField>,
|
||||
/// The number of [`SecureField`]s in the vector.
|
||||
pub length: usize,
|
||||
}
|
||||
|
||||
impl SecureColumn {
|
||||
// Separates a single column of `PackedSecureField` elements into `SECURE_EXTENSION_DEGREE` many
|
||||
// `PackedBaseField` coordinate columns.
|
||||
pub fn into_secure_column_by_coords(self) -> SecureColumnByCoords<SimdBackend> {
|
||||
if self.len() < N_LANES {
|
||||
return self.to_cpu().into_iter().collect();
|
||||
}
|
||||
|
||||
let length = self.length;
|
||||
let packed_length = self.data.len();
|
||||
let mut columns = array::from_fn(|_| Vec::with_capacity(packed_length));
|
||||
|
||||
for v in self.data {
|
||||
let packed_coords = v.into_packed_m31s();
|
||||
zip(&mut columns, packed_coords).for_each(|(col, packed_coord)| col.push(packed_coord));
|
||||
}
|
||||
|
||||
SecureColumnByCoords {
|
||||
columns: columns.map(|col| BaseColumn { data: col, length }),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Column<SecureField> for SecureColumn {
|
||||
fn zeros(length: usize) -> Self {
|
||||
Self {
|
||||
data: vec![PackedSecureField::zeroed(); length.div_ceil(N_LANES)],
|
||||
length,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::uninit_vec)]
|
||||
unsafe fn uninitialized(length: usize) -> Self {
|
||||
let mut data = Vec::with_capacity(length.div_ceil(N_LANES));
|
||||
data.set_len(length.div_ceil(N_LANES));
|
||||
Self { data, length }
|
||||
}
|
||||
|
||||
fn to_cpu(&self) -> Vec<SecureField> {
|
||||
self.data
|
||||
.iter()
|
||||
.flat_map(|x| x.to_array())
|
||||
.take(self.length)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.length
|
||||
}
|
||||
|
||||
fn at(&self, index: usize) -> SecureField {
|
||||
self.data[index / N_LANES].to_array()[index % N_LANES]
|
||||
}
|
||||
|
||||
fn set(&mut self, index: usize, value: SecureField) {
|
||||
let mut packed = self.data[index / N_LANES].to_array();
|
||||
packed[index % N_LANES] = value;
|
||||
self.data[index / N_LANES] = PackedSecureField::from_array(packed)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<SecureField> for SecureColumn {
|
||||
fn from_iter<I: IntoIterator<Item = SecureField>>(iter: I) -> Self {
|
||||
let mut chunks = iter.into_iter().array_chunks();
|
||||
let mut data = (&mut chunks)
|
||||
.map(PackedSecureField::from_array)
|
||||
.collect_vec();
|
||||
let mut length = data.len() * N_LANES;
|
||||
|
||||
if let Some(remainder) = chunks.into_remainder() {
|
||||
if !remainder.is_empty() {
|
||||
length += remainder.len();
|
||||
let mut last = [SecureField::zero(); N_LANES];
|
||||
last[..remainder.len()].copy_from_slice(remainder.as_slice());
|
||||
data.push(PackedSecureField::from_array(last));
|
||||
}
|
||||
}
|
||||
|
||||
Self { data, length }
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<PackedSecureField> for SecureColumn {
|
||||
fn from_iter<I: IntoIterator<Item = PackedSecureField>>(iter: I) -> Self {
|
||||
let data = iter.into_iter().collect_vec();
|
||||
let length = data.len() * N_LANES;
|
||||
Self { data, length }
|
||||
}
|
||||
}
|
||||
|
||||
/// A mutable slice of a SecureColumnByCoords.
|
||||
pub struct SecureColumnByCoordsMutSlice<'a>(pub [BaseColumnMutSlice<'a>; SECURE_EXTENSION_DEGREE]);
|
||||
|
||||
impl<'a> SecureColumnByCoordsMutSlice<'a> {
|
||||
/// # Safety
|
||||
///
|
||||
/// `vec_index` must be a valid index.
|
||||
pub unsafe fn packed_at(&self, vec_index: usize) -> PackedSecureField {
|
||||
PackedQM31([
|
||||
PackedCM31([
|
||||
*self.0[0].0.get_unchecked(vec_index),
|
||||
*self.0[1].0.get_unchecked(vec_index),
|
||||
]),
|
||||
PackedCM31([
|
||||
*self.0[2].0.get_unchecked(vec_index),
|
||||
*self.0[3].0.get_unchecked(vec_index),
|
||||
]),
|
||||
])
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `vec_index` must be a valid index.
|
||||
pub unsafe fn set_packed(&mut self, vec_index: usize, value: PackedSecureField) {
|
||||
let PackedQM31([PackedCM31([a, b]), PackedCM31([c, d])]) = value;
|
||||
*self.0[0].0.get_unchecked_mut(vec_index) = a;
|
||||
*self.0[1].0.get_unchecked_mut(vec_index) = b;
|
||||
*self.0[2].0.get_unchecked_mut(vec_index) = c;
|
||||
*self.0[3].0.get_unchecked_mut(vec_index) = d;
|
||||
}
|
||||
}
|
||||
|
||||
impl SecureColumnByCoords<SimdBackend> {
|
||||
pub fn packed_len(&self) -> usize {
|
||||
self.columns[0].data.len()
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `vec_index` must be a valid index.
|
||||
pub unsafe fn packed_at(&self, vec_index: usize) -> PackedSecureField {
|
||||
PackedQM31([
|
||||
PackedCM31([
|
||||
*self.columns[0].data.get_unchecked(vec_index),
|
||||
*self.columns[1].data.get_unchecked(vec_index),
|
||||
]),
|
||||
PackedCM31([
|
||||
*self.columns[2].data.get_unchecked(vec_index),
|
||||
*self.columns[3].data.get_unchecked(vec_index),
|
||||
]),
|
||||
])
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `vec_index` must be a valid index.
|
||||
pub unsafe fn set_packed(&mut self, vec_index: usize, value: PackedSecureField) {
|
||||
let PackedQM31([PackedCM31([a, b]), PackedCM31([c, d])]) = value;
|
||||
*self.columns[0].data.get_unchecked_mut(vec_index) = a;
|
||||
*self.columns[1].data.get_unchecked_mut(vec_index) = b;
|
||||
*self.columns[2].data.get_unchecked_mut(vec_index) = c;
|
||||
*self.columns[3].data.get_unchecked_mut(vec_index) = d;
|
||||
}
|
||||
|
||||
pub fn to_vec(&self) -> Vec<SecureField> {
|
||||
izip!(
|
||||
self.columns[0].to_cpu(),
|
||||
self.columns[1].to_cpu(),
|
||||
self.columns[2].to_cpu(),
|
||||
self.columns[3].to_cpu(),
|
||||
)
|
||||
.map(|(a, b, c, d)| SecureField::from_m31_array([a, b, c, d]))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Returns a vector of `SecureColumnByCoordsMutSlice`s, each mutably owning
|
||||
/// `SECURE_EXTENSION_DEGREE` slices of `chunk_size` `PackedBaseField`s
|
||||
/// (i.e., `chunk_size` * `N_LANES` secure field elements, by coordinates).
|
||||
pub fn chunks_mut(&mut self, chunk_size: usize) -> Vec<SecureColumnByCoordsMutSlice<'_>> {
|
||||
let [a, b, c, d] = self
|
||||
.columns
|
||||
.get_many_mut([0, 1, 2, 3])
|
||||
.unwrap()
|
||||
.map(|x| x.chunks_mut(chunk_size));
|
||||
izip!(a, b, c, d)
|
||||
.map(|(a, b, c, d)| SecureColumnByCoordsMutSlice([a, b, c, d]))
|
||||
.collect_vec()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<SecureField> for SecureColumnByCoords<SimdBackend> {
|
||||
fn from_iter<I: IntoIterator<Item = SecureField>>(iter: I) -> Self {
|
||||
let cpu_col = SecureColumnByCoords::<CpuBackend>::from_iter(iter);
|
||||
let columns = cpu_col.columns.map(|col| col.into_iter().collect());
|
||||
SecureColumnByCoords { columns }
|
||||
}
|
||||
}
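// Hypothetical test added for illustration (not part of the diff): `to_vec` reassembles each
// [`SecureField`] from its four coordinate columns, so collecting values and reading them
// back is the identity.
#[test]
fn secure_column_by_coords_to_vec_roundtrip() {
    let values = (0..40usize)
        .map(|i| SecureField::from(BaseField::from(i)))
        .collect_vec();
    let col = values.iter().copied().collect::<SecureColumnByCoords<SimdBackend>>();
    assert_eq!(col.to_vec(), values);
}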
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct VeryPackedBaseColumn {
|
||||
pub data: Vec<VeryPackedBaseField>,
|
||||
/// The number of [`BaseField`]s in the vector.
|
||||
pub length: usize,
|
||||
}
|
||||
|
||||
impl VeryPackedBaseColumn {
|
||||
/// Transforms a `&BaseColumn` to a `&VeryPackedBaseColumn`.
|
||||
/// # Safety
|
||||
///
|
||||
/// The resulting pointer does not update the underlying `data`'s length.
|
||||
pub unsafe fn transform_under_ref(value: &BaseColumn) -> &Self {
|
||||
&*(std::ptr::addr_of!(*value) as *const VeryPackedBaseColumn)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<BaseColumn> for VeryPackedBaseColumn {
|
||||
fn from(value: BaseColumn) -> Self {
|
||||
Self {
|
||||
data: cast_vec(value.data),
|
||||
length: value.length,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<VeryPackedBaseColumn> for BaseColumn {
|
||||
fn from(value: VeryPackedBaseColumn) -> Self {
|
||||
Self {
|
||||
data: cast_vec(value.data),
|
||||
length: value.length,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<BaseField> for VeryPackedBaseColumn {
|
||||
fn from_iter<I: IntoIterator<Item = BaseField>>(iter: I) -> Self {
|
||||
BaseColumn::from_iter(iter).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Column<BaseField> for VeryPackedBaseColumn {
|
||||
fn zeros(length: usize) -> Self {
|
||||
BaseColumn::zeros(length).into()
|
||||
}
|
||||
|
||||
#[allow(clippy::uninit_vec)]
|
||||
unsafe fn uninitialized(length: usize) -> Self {
|
||||
BaseColumn::uninitialized(length).into()
|
||||
}
|
||||
|
||||
fn to_cpu(&self) -> Vec<BaseField> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.length
|
||||
}
|
||||
|
||||
fn at(&self, index: usize) -> BaseField {
|
||||
let chunk_size = N_LANES * N_VERY_PACKED_ELEMS;
|
||||
self.data[index / chunk_size].to_array()[index % chunk_size]
|
||||
}
|
||||
|
||||
fn set(&mut self, index: usize, value: BaseField) {
|
||||
let chunk_size = N_LANES * N_VERY_PACKED_ELEMS;
|
||||
let mut packed = self.data[index / chunk_size].to_array();
|
||||
packed[index % chunk_size] = value;
|
||||
self.data[index / chunk_size] = VeryPackedBaseField::from_array(packed)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct VeryPackedSecureColumnByCoords {
|
||||
pub columns: [VeryPackedBaseColumn; SECURE_EXTENSION_DEGREE],
|
||||
}
|
||||
|
||||
impl From<SecureColumnByCoords<SimdBackend>> for VeryPackedSecureColumnByCoords {
|
||||
fn from(value: SecureColumnByCoords<SimdBackend>) -> Self {
|
||||
Self {
|
||||
columns: value
|
||||
.columns
|
||||
.into_iter()
|
||||
.map(VeryPackedBaseColumn::from)
|
||||
.collect_vec()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<VeryPackedSecureColumnByCoords> for SecureColumnByCoords<SimdBackend> {
|
||||
fn from(value: VeryPackedSecureColumnByCoords) -> Self {
|
||||
Self {
|
||||
columns: value
|
||||
.columns
|
||||
.into_iter()
|
||||
.map(BaseColumn::from)
|
||||
.collect_vec()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VeryPackedSecureColumnByCoords {
|
||||
pub fn packed_len(&self) -> usize {
|
||||
self.columns[0].data.len()
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `vec_index` must be a valid index.
|
||||
pub unsafe fn packed_at(&self, vec_index: usize) -> VeryPackedSecureField {
|
||||
VeryPackedSecureField::from_fn(|i| {
|
||||
PackedQM31([
|
||||
PackedCM31([
|
||||
self.columns[0].data.get_unchecked(vec_index).0[i],
|
||||
self.columns[1].data.get_unchecked(vec_index).0[i],
|
||||
]),
|
||||
PackedCM31([
|
||||
self.columns[2].data.get_unchecked(vec_index).0[i],
|
||||
self.columns[3].data.get_unchecked(vec_index).0[i],
|
||||
]),
|
||||
])
|
||||
})
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// `vec_index` must be a valid index.
|
||||
pub unsafe fn set_packed(&mut self, vec_index: usize, value: VeryPackedSecureField) {
|
||||
for i in 0..N_VERY_PACKED_ELEMS {
|
||||
let PackedQM31([PackedCM31([a, b]), PackedCM31([c, d])]) = value.0[i];
|
||||
self.columns[0].data.get_unchecked_mut(vec_index).0[i] = a;
|
||||
self.columns[1].data.get_unchecked_mut(vec_index).0[i] = b;
|
||||
self.columns[2].data.get_unchecked_mut(vec_index).0[i] = c;
|
||||
self.columns[3].data.get_unchecked_mut(vec_index).0[i] = d;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_vec(&self) -> Vec<SecureField> {
|
||||
izip!(
|
||||
self.columns[0].to_cpu(),
|
||||
self.columns[1].to_cpu(),
|
||||
self.columns[2].to_cpu(),
|
||||
self.columns[3].to_cpu(),
|
||||
)
|
||||
.map(|(a, b, c, d)| SecureField::from_m31_array([a, b, c, d]))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Transforms a `&mut SecureColumnByCoords<SimdBackend>` to a
|
||||
/// `&mut VeryPackedSecureColumnByCoords`.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The resulting pointer does not update the underlying columns' `data`'s lengths.
|
||||
pub unsafe fn transform_under_mut(value: &mut SecureColumnByCoords<SimdBackend>) -> &mut Self {
|
||||
&mut *(std::ptr::addr_of!(*value) as *mut VeryPackedSecureColumnByCoords)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::array;
|
||||
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use super::BaseColumn;
|
||||
use crate::core::backend::simd::column::SecureColumn;
|
||||
use crate::core::backend::simd::m31::N_LANES;
|
||||
use crate::core::backend::simd::qm31::PackedQM31;
|
||||
use crate::core::backend::Column;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::secure_column::SecureColumnByCoords;
|
||||
|
||||
#[test]
|
||||
fn base_field_vec_from_iter_works() {
|
||||
let values: [BaseField; 30] = array::from_fn(BaseField::from);
|
||||
|
||||
let res = values.into_iter().collect::<BaseColumn>();
|
||||
|
||||
assert_eq!(res.to_cpu(), values);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn secure_field_vec_from_iter_works() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let values: [SecureField; 30] = rng.gen();
|
||||
|
||||
let res = values.into_iter().collect::<SecureColumn>();
|
||||
|
||||
assert_eq!(res.to_cpu(), values);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_base_column_chunks_mut() {
|
||||
let values: [BaseField; N_LANES * 7] = array::from_fn(BaseField::from);
|
||||
let mut col = values.into_iter().collect::<BaseColumn>();
|
||||
|
||||
const CHUNK_SIZE: usize = 2;
|
||||
let mut chunks = col.chunks_mut(CHUNK_SIZE);
|
||||
chunks[2].set(19, BaseField::from(1234));
|
||||
chunks[3].set(1, BaseField::from(5678));
|
||||
|
||||
assert_eq!(col.at(2 * CHUNK_SIZE * N_LANES + 19), BaseField::from(1234));
|
||||
assert_eq!(col.at(3 * CHUNK_SIZE * N_LANES + 1), BaseField::from(5678));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_secure_column_by_coords_chunks_mut() {
|
||||
const COL_PACKED_SIZE: usize = 16;
|
||||
let a: [BaseField; N_LANES * COL_PACKED_SIZE] = array::from_fn(BaseField::from);
|
||||
let b: [BaseField; N_LANES * COL_PACKED_SIZE] = array::from_fn(BaseField::from);
|
||||
let c: [BaseField; N_LANES * COL_PACKED_SIZE] = array::from_fn(BaseField::from);
|
||||
let d: [BaseField; N_LANES * COL_PACKED_SIZE] = array::from_fn(BaseField::from);
|
||||
let mut col = SecureColumnByCoords {
|
||||
columns: [a, b, c, d].map(|values| values.into_iter().collect::<BaseColumn>()),
|
||||
};
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let rand0 = PackedQM31::from_array(rng.gen());
|
||||
let rand1 = PackedQM31::from_array(rng.gen());
|
||||
|
||||
const CHUNK_SIZE: usize = 4;
|
||||
let mut chunks = col.chunks_mut(CHUNK_SIZE);
|
||||
unsafe {
|
||||
chunks[2].set_packed(3, rand0);
|
||||
chunks[3].set_packed(1, rand1);
|
||||
|
||||
assert_eq!(
|
||||
col.packed_at(2 * CHUNK_SIZE + 3).to_array(),
|
||||
rand0.to_array()
|
||||
);
|
||||
assert_eq!(
|
||||
col.packed_at(3 * CHUNK_SIZE + 1).to_array(),
|
||||
rand1.to_array()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
86
Stwo_wrapper/crates/prover/src/core/backend/simd/domain.rs
Normal file
@ -0,0 +1,86 @@
|
||||
use std::simd::{simd_swizzle, u32x2, Simd};
|
||||
|
||||
use super::m31::{PackedM31, LOG_N_LANES};
|
||||
use crate::core::circle::{CirclePoint, M31_CIRCLE_LOG_ORDER};
|
||||
use crate::core::fields::m31::M31;
|
||||
use crate::core::poly::circle::CircleDomain;
|
||||
use crate::core::utils::bit_reverse_index;
|
||||
|
||||
pub struct CircleDomainBitRevIterator {
|
||||
domain: CircleDomain,
|
||||
i: usize,
|
||||
current: CirclePoint<PackedM31>,
|
||||
flips: [CirclePoint<M31>; (M31_CIRCLE_LOG_ORDER - LOG_N_LANES) as usize],
|
||||
}
|
||||
impl CircleDomainBitRevIterator {
|
||||
pub fn new(domain: CircleDomain) -> Self {
|
||||
let log_size = domain.log_size();
|
||||
assert!(log_size >= LOG_N_LANES);
|
||||
|
||||
let initial_points = std::array::from_fn(|i| domain.at(bit_reverse_index(i, log_size)));
|
||||
let current = CirclePoint {
|
||||
x: PackedM31::from_array(initial_points.each_ref().map(|p| p.x)),
|
||||
y: PackedM31::from_array(initial_points.each_ref().map(|p| p.y)),
|
||||
};
|
||||
|
||||
let mut flips = [CirclePoint::zero(); (M31_CIRCLE_LOG_ORDER - LOG_N_LANES) as usize];
|
||||
for i in 0..(log_size - LOG_N_LANES) {
|
||||
// L i
|
||||
// 000111000000 ->
|
||||
// 000000100000
|
||||
let prev_mul = bit_reverse_index((1 << i) - 1, log_size - LOG_N_LANES);
|
||||
let new_mul = bit_reverse_index(1 << i, log_size - LOG_N_LANES);
|
||||
let flip = domain.half_coset.step.mul(new_mul as u128)
|
||||
- domain.half_coset.step.mul(prev_mul as u128);
|
||||
flips[i as usize] = flip;
|
||||
}
|
||||
Self {
|
||||
domain,
|
||||
i: 0,
|
||||
current,
|
||||
flips,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Iterator for CircleDomainBitRevIterator {
|
||||
type Item = CirclePoint<PackedM31>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.i << LOG_N_LANES >= self.domain.size() {
|
||||
return None;
|
||||
}
|
||||
let res = self.current;
|
||||
let flip = self.flips[self.i.trailing_ones() as usize];
|
||||
let flipx = Simd::splat(flip.x.0);
|
||||
let flipy = u32x2::from_array([flip.y.0, (-flip.y).0]);
|
||||
let flipy = simd_swizzle!(flipy, [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]);
|
||||
let flip = unsafe {
|
||||
CirclePoint {
|
||||
x: PackedM31::from_simd_unchecked(flipx),
|
||||
y: PackedM31::from_simd_unchecked(flipy),
|
||||
}
|
||||
};
|
||||
self.current = self.current + flip;
|
||||
self.i += 1;
|
||||
Some(res)
|
||||
}
|
||||
}
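// Sketch added for illustration (not part of the diff): advancing the iterator index from
// i to i + 1 clears the trailing ones and newly sets the bit at position `i.trailing_ones()`,
// which is why one precomputed flip per bit position suffices.
#[test]
fn trailing_ones_selects_the_new_high_bit() {
    for i in 0usize..64 {
        let bit = i.trailing_ones() as usize;
        // All trailing ones are cleared and bit `bit` is newly set.
        assert_eq!((i + 1) ^ i, (1 << (bit + 1)) - 1);
        assert_ne!((i + 1) & (1 << bit), 0);
    }
}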
|
||||
|
||||
#[test]
|
||||
fn test_circle_domain_bit_rev_iterator() {
|
||||
let domain = CircleDomain::new(crate::core::circle::Coset::new(
|
||||
crate::core::circle::CirclePointIndex::generator(),
|
||||
5,
|
||||
));
|
||||
let mut expected = domain.iter().collect::<Vec<_>>();
|
||||
crate::core::utils::bit_reverse(&mut expected);
|
||||
let actual = CircleDomainBitRevIterator::new(domain)
|
||||
.flat_map(|c| -> [_; 16] {
|
||||
std::array::from_fn(|i| CirclePoint {
|
||||
x: c.x.to_array()[i],
|
||||
y: c.y.to_array()[i],
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(actual, expected);
|
||||
}
|
||||
712
Stwo_wrapper/crates/prover/src/core/backend/simd/fft/ifft.rs
Normal file
@ -0,0 +1,712 @@
|
||||
//! Inverse fft.
|
||||
|
||||
use std::simd::{simd_swizzle, u32x16, u32x2, u32x4};
|
||||
|
||||
use itertools::Itertools;
|
||||
|
||||
use super::{
|
||||
compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
|
||||
};
|
||||
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
|
||||
use crate::core::circle::Coset;
|
||||
use crate::core::fields::FieldExpOps;
|
||||
use crate::core::utils::bit_reverse;
|
||||
|
||||
/// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `values`: A mutable pointer to the values on which the ICFFT is to be performed.
|
||||
/// - `twiddle_dbl`: A reference to the doubles of the twiddle factors.
|
||||
/// - `log_n_elements`: The log of the number of elements in the `values` array.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `log_n_elements` is less than [`MIN_FFT_LOG_SIZE`].
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
|
||||
pub unsafe fn ifft(values: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) {
|
||||
assert!(log_n_elements >= MIN_FFT_LOG_SIZE as usize);
|
||||
let log_n_vecs = log_n_elements - LOG_N_LANES as usize;
|
||||
if log_n_elements <= CACHED_FFT_LOG_SIZE as usize {
|
||||
ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements);
|
||||
return;
|
||||
}
|
||||
|
||||
let fft_layers_pre_transpose = log_n_vecs.div_ceil(2);
|
||||
let fft_layers_post_transpose = log_n_vecs / 2;
|
||||
ifft_lower_with_vecwise(
|
||||
values,
|
||||
&twiddle_dbl[..3 + fft_layers_pre_transpose],
|
||||
log_n_elements,
|
||||
fft_layers_pre_transpose + LOG_N_LANES as usize,
|
||||
);
|
||||
transpose_vecs(values, log_n_vecs);
|
||||
ifft_lower_without_vecwise(
|
||||
values,
|
||||
&twiddle_dbl[3 + fft_layers_pre_transpose..],
|
||||
log_n_elements,
|
||||
fft_layers_post_transpose,
|
||||
);
|
||||
}
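// Sketch added for illustration (not part of the diff): above the cached size the layers are
// split into a vecwise pass of ceil(log_n_vecs / 2) layers, a transpose, and a second pass of
// floor(log_n_vecs / 2) layers, which together cover every layer exactly once.
#[test]
fn ifft_layer_split_covers_all_vec_layers() {
    for log_n_vecs in 1usize..32 {
        assert_eq!(log_n_vecs.div_ceil(2) + log_n_vecs / 2, log_n_vecs);
    }
}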
|
||||
|
||||
/// Computes partial ifft on `2^log_size` M31 elements.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `values`: Pointer to the entire value array, aligned to 64 bytes.
|
||||
/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the ifft. Layer i
|
||||
/// holds `2^(log_size - 1 - i)` twiddles.
|
||||
/// - `log_size`: The log of the number of M31 elements in the array.
|
||||
/// - `fft_layers`: The number of ifft layers to apply, out of log_size.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `log_size` is not at least 5.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// `values` must have the same alignment as [`PackedBaseField`].
|
||||
/// `fft_layers` must be at least 5.
|
||||
pub unsafe fn ifft_lower_with_vecwise(
|
||||
values: *mut u32,
|
||||
twiddle_dbl: &[&[u32]],
|
||||
log_size: usize,
|
||||
fft_layers: usize,
|
||||
) {
|
||||
const VECWISE_FFT_BITS: usize = LOG_N_LANES as usize + 1;
|
||||
assert!(log_size >= VECWISE_FFT_BITS);
|
||||
|
||||
assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2));
|
||||
|
||||
for index_h in 0..1 << (log_size - fft_layers) {
|
||||
ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h);
|
||||
for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) {
|
||||
match fft_layers - layer {
|
||||
1 => {
|
||||
ifft1_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h);
|
||||
}
|
||||
2 => {
|
||||
ifft2_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h);
|
||||
}
|
||||
_ => {
|
||||
ifft3_loop(
|
||||
values,
|
||||
&twiddle_dbl[(layer - 1)..],
|
||||
fft_layers - layer - 3,
|
||||
layer,
|
||||
index_h,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes partial ifft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of
|
||||
/// the index).
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `values`: Pointer to the entire value array, aligned to 64 bytes.
|
||||
/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the ifft.
|
||||
/// - `log_size`: The log of the number of M31 elements in the array.
|
||||
/// - `fft_layers`: The number of ifft layers to apply, out of `log_size - LOG_N_LANES`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `log_size` is not at least 4.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// `values` must have the same alignment as [`PackedBaseField`].
|
||||
/// `fft_layers` must be at least 4.
|
||||
pub unsafe fn ifft_lower_without_vecwise(
|
||||
values: *mut u32,
|
||||
twiddle_dbl: &[&[u32]],
|
||||
log_size: usize,
|
||||
fft_layers: usize,
|
||||
) {
|
||||
assert!(log_size >= LOG_N_LANES as usize);
|
||||
|
||||
for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) {
|
||||
for layer in (0..fft_layers).step_by(3) {
|
||||
let fixed_layer = layer + LOG_N_LANES as usize;
|
||||
match fft_layers - layer {
|
||||
1 => {
|
||||
ifft1_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h);
|
||||
}
|
||||
2 => {
|
||||
ifft2_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h);
|
||||
}
|
||||
_ => {
|
||||
ifft3_loop(
|
||||
values,
|
||||
&twiddle_dbl[layer..],
|
||||
fft_layers - layer - 3,
|
||||
fixed_layer,
|
||||
index_h,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs the first 5 ifft layers across the entire array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `values`: Pointer to the entire value array, aligned to 64 bytes.
|
||||
/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 5 ifft layers.
|
||||
/// - `loop_bits`: The number of bits this loop needs to run on.
|
||||
/// - `index_h`: The higher part of the index, iterated by the caller.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
|
||||
pub unsafe fn ifft_vecwise_loop(
|
||||
values: *mut u32,
|
||||
twiddle_dbl: &[&[u32]],
|
||||
loop_bits: usize,
|
||||
index_h: usize,
|
||||
) {
|
||||
for index_l in 0..1 << loop_bits {
|
||||
let index = (index_h << loop_bits) + index_l;
|
||||
let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const());
|
||||
let mut val1 = PackedBaseField::load(values.add(index * 32 + 16).cast_const());
|
||||
(val0, val1) = vecwise_ibutterflies(
|
||||
val0,
|
||||
val1,
|
||||
std::array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)),
|
||||
std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)),
|
||||
std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)),
|
||||
);
|
||||
(val0, val1) = simd_ibutterfly(
|
||||
val0,
|
||||
val1,
|
||||
u32x16::splat(*twiddle_dbl[3].get_unchecked(index)),
|
||||
);
|
||||
val0.store(values.add(index * 32));
|
||||
val1.store(values.add(index * 32 + 16));
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs 3 ifft layers across the entire array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `values`: Pointer to the entire value array, aligned to 64 bytes.
|
||||
/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 3 ifft layers.
|
||||
/// - `loop_bits`: The number of bits this loop needs to run on.
|
||||
/// - `layer`: The layer number of the first ifft layer to apply. The layers `layer`, `layer + 1`,
|
||||
/// `layer + 2` are applied.
|
||||
/// - `index_h`: The higher part of the index, iterated by the caller.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
|
||||
pub unsafe fn ifft3_loop(
|
||||
values: *mut u32,
|
||||
twiddle_dbl: &[&[u32]],
|
||||
loop_bits: usize,
|
||||
layer: usize,
|
||||
index_h: usize,
|
||||
) {
|
||||
for index_l in 0..1 << loop_bits {
|
||||
let index = (index_h << loop_bits) + index_l;
|
||||
let offset = index << (layer + 3);
|
||||
for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) {
|
||||
ifft3(
|
||||
values,
|
||||
offset + l,
|
||||
layer,
|
||||
std::array::from_fn(|i| {
|
||||
*twiddle_dbl[0].get_unchecked((index * 4 + i) & (twiddle_dbl[0].len() - 1))
|
||||
}),
|
||||
std::array::from_fn(|i| {
|
||||
*twiddle_dbl[1].get_unchecked((index * 2 + i) & (twiddle_dbl[1].len() - 1))
|
||||
}),
|
||||
std::array::from_fn(|i| {
|
||||
*twiddle_dbl[2].get_unchecked((index + i) & (twiddle_dbl[2].len() - 1))
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs 2 ifft layers across the entire array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `values`: Pointer to the entire value array, aligned to 64 bytes.
|
||||
/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 2 ifft layers.
|
||||
/// - `loop_bits`: The number of bits this loop needs to run on.
|
||||
/// - `layer`: The layer number of the first ifft layer to apply. The layers `layer`, `layer + 1`
|
||||
/// are applied.
|
||||
/// - `index`: The index, iterated by the caller.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
|
||||
unsafe fn ifft2_loop(values: *mut u32, twiddle_dbl: &[&[u32]], layer: usize, index: usize) {
|
||||
let offset = index << (layer + 2);
|
||||
for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) {
|
||||
ifft2(
|
||||
values,
|
||||
offset + l,
|
||||
layer,
|
||||
std::array::from_fn(|i| {
|
||||
*twiddle_dbl[0].get_unchecked((index * 2 + i) & (twiddle_dbl[0].len() - 1))
|
||||
}),
|
||||
std::array::from_fn(|i| {
|
||||
*twiddle_dbl[1].get_unchecked((index + i) & (twiddle_dbl[1].len() - 1))
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs 1 ifft layer across the entire array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `values`: Pointer to the entire value array, aligned to 64 bytes.
|
||||
/// - `twiddle_dbl`: The doubles of the twiddle factors for the ifft layer.
|
||||
/// - `layer`: The layer number of the ifft layer to apply.
|
||||
/// - `index_h`: The higher part of the index, iterated by the caller.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
|
||||
unsafe fn ifft1_loop(values: *mut u32, twiddle_dbl: &[&[u32]], layer: usize, index: usize) {
|
||||
let offset = index << (layer + 1);
|
||||
for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) {
|
||||
ifft1(
|
||||
values,
|
||||
offset + l,
|
||||
layer,
|
||||
std::array::from_fn(|i| {
|
||||
*twiddle_dbl[0].get_unchecked((index + i) & (twiddle_dbl[0].len() - 1))
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the ibutterfly operation for packed M31 elements.
|
||||
///
|
||||
/// Returns `val0 + val1, t (val0 - val1)`. `val0, val1` are packed M31 elements, 16 M31 elements
/// each. Each value is assumed to be in unreduced form, [0, P] including P. `twiddle_dbl` holds 16
/// values, each the *double* of a twiddle factor, in unreduced form.
|
||||
pub fn simd_ibutterfly(
|
||||
val0: PackedBaseField,
|
||||
val1: PackedBaseField,
|
||||
twiddle_dbl: u32x16,
|
||||
) -> (PackedBaseField, PackedBaseField) {
|
||||
let r0 = val0 + val1;
|
||||
let r1 = val0 - val1;
|
||||
let prod = mul_twiddle(r1, twiddle_dbl);
|
||||
(r0, prod)
|
||||
}
|
||||
|
||||
/// Runs ifft on 2 vectors of 16 M31 elements.
|
||||
///
|
||||
/// This amounts to 4 butterfly layers, each with 16 butterflies.
|
||||
/// Each of the vectors represents a bit reversed evaluation.
|
||||
/// Each value in the vectors is in unreduced form: [0, P] including P.
|
||||
/// Takes 3 twiddle arrays, one for each layer after the first, holding the double of the
|
||||
/// corresponding twiddle.
|
||||
/// The first layer's twiddles (lower bit of the index) are computed from the second layer's
|
||||
/// twiddles. The second layer takes 8 twiddles.
|
||||
/// The third layer takes 4 twiddles.
|
||||
/// The fourth layer takes 2 twiddles.
|
||||
pub fn vecwise_ibutterflies(
|
||||
mut val0: PackedBaseField,
|
||||
mut val1: PackedBaseField,
|
||||
twiddle1_dbl: [u32; 8],
|
||||
twiddle2_dbl: [u32; 4],
|
||||
twiddle3_dbl: [u32; 2],
|
||||
) -> (PackedBaseField, PackedBaseField) {
|
||||
// TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly.
|
||||
|
||||
// Each `ibutterfly` takes 2 512-bit registers, and does 16 butterflies element by element.
|
||||
// We need to permute the 512-bit registers to get the right order for the butterflies.
|
||||
// Denote the index of the 16 M31 elements in register i as i:abcd.
|
||||
// At each layer we apply the following permutation to the index:
|
||||
// i:abcd => d:iabc
|
||||
// This is how it looks at each iteration.
|
||||
// i:abcd
|
||||
// d:iabc
|
||||
// ifft on d
|
||||
// c:diab
|
||||
// ifft on c
|
||||
// b:cdia
|
||||
// ifft on b
|
||||
// a:bcid
|
||||
// ifft on a
|
||||
// i:abcd
|
||||
|
||||
let (t0, t1) = compute_first_twiddles(twiddle1_dbl.into());
|
||||
|
||||
// Apply the permutation, resulting in indexing d:iabc.
|
||||
(val0, val1) = val0.deinterleave(val1);
|
||||
(val0, val1) = simd_ibutterfly(val0, val1, t0);
|
||||
|
||||
// Apply the permutation, resulting in indexing c:diab.
|
||||
(val0, val1) = val0.deinterleave(val1);
|
||||
(val0, val1) = simd_ibutterfly(val0, val1, t1);
|
||||
|
||||
let t = simd_swizzle!(
|
||||
u32x4::from(twiddle2_dbl),
|
||||
[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
|
||||
);
|
||||
// Apply the permutation, resulting in indexing b:cdia.
|
||||
(val0, val1) = val0.deinterleave(val1);
|
||||
(val0, val1) = simd_ibutterfly(val0, val1, t);
|
||||
|
||||
let t = simd_swizzle!(
|
||||
u32x2::from(twiddle3_dbl),
|
||||
[0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
|
||||
);
|
||||
// Apply the permutation, resulting in indexing a:bcid.
|
||||
(val0, val1) = val0.deinterleave(val1);
|
||||
(val0, val1) = simd_ibutterfly(val0, val1, t);
|
||||
|
||||
// Apply the permutation, resulting in indexing i:abcd.
|
||||
val0.deinterleave(val1)
|
||||
}
|
||||
|
||||
/// Returns the line twiddles (x points) for an ifft on a coset.
|
||||
pub fn get_itwiddle_dbls(mut coset: Coset) -> Vec<Vec<u32>> {
|
||||
let mut res = vec![];
|
||||
for _ in 0..coset.log_size() {
|
||||
res.push(
|
||||
coset
|
||||
.iter()
|
||||
.take(coset.size() / 2)
|
||||
.map(|p| p.x.inverse().0 * 2)
|
||||
.collect_vec(),
|
||||
);
|
||||
bit_reverse(res.last_mut().unwrap());
|
||||
coset = coset.double();
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
/// Applies 3 ibutterfly layers on 8 vectors of 16 M31 elements.
|
||||
///
|
||||
/// Vectorized over the 16 elements of the vectors.
|
||||
/// Used for radix-8 ifft.
|
||||
/// Each butterfly layer has 4 SIMD butterflies.
|
||||
/// Total of 12 SIMD butterflies.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `values`: Pointer to the entire value array.
|
||||
/// - `offset`: The offset of the first value in the array.
|
||||
/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values
|
||||
/// that need to be transformed. For layer i this is i - 4.
|
||||
/// - `twiddles_dbl0/1/2`: The double of the twiddles for the 3 layers of ibutterflies. Each layer
|
||||
/// has 4/2/1 twiddles.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
|
||||
pub unsafe fn ifft3(
|
||||
values: *mut u32,
|
||||
offset: usize,
|
||||
log_step: usize,
|
||||
twiddles_dbl0: [u32; 4],
|
||||
twiddles_dbl1: [u32; 2],
|
||||
twiddles_dbl2: [u32; 1],
|
||||
) {
|
||||
// Load the 8 SIMD vectors from the array.
|
||||
let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const());
|
||||
let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const());
|
||||
let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const());
|
||||
let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const());
|
||||
let mut val4 = PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const());
|
||||
let mut val5 = PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const());
|
||||
let mut val6 = PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const());
|
||||
let mut val7 = PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const());
|
||||
|
||||
// Apply the first layer of ibutterflies.
|
||||
(val0, val1) = simd_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0]));
|
||||
(val2, val3) = simd_ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1]));
|
||||
(val4, val5) = simd_ibutterfly(val4, val5, u32x16::splat(twiddles_dbl0[2]));
|
||||
(val6, val7) = simd_ibutterfly(val6, val7, u32x16::splat(twiddles_dbl0[3]));
|
||||
|
||||
// Apply the second layer of ibutterflies.
|
||||
(val0, val2) = simd_ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0]));
|
||||
(val1, val3) = simd_ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0]));
|
||||
(val4, val6) = simd_ibutterfly(val4, val6, u32x16::splat(twiddles_dbl1[1]));
|
||||
(val5, val7) = simd_ibutterfly(val5, val7, u32x16::splat(twiddles_dbl1[1]));
|
||||
|
||||
// Apply the third layer of ibutterflies.
|
||||
(val0, val4) = simd_ibutterfly(val0, val4, u32x16::splat(twiddles_dbl2[0]));
|
||||
(val1, val5) = simd_ibutterfly(val1, val5, u32x16::splat(twiddles_dbl2[0]));
|
||||
(val2, val6) = simd_ibutterfly(val2, val6, u32x16::splat(twiddles_dbl2[0]));
|
||||
(val3, val7) = simd_ibutterfly(val3, val7, u32x16::splat(twiddles_dbl2[0]));
|
||||
|
||||
// Store the 8 SIMD vectors back to the array.
|
||||
val0.store(values.add(offset + (0 << log_step)));
|
||||
val1.store(values.add(offset + (1 << log_step)));
|
||||
val2.store(values.add(offset + (2 << log_step)));
|
||||
val3.store(values.add(offset + (3 << log_step)));
|
||||
val4.store(values.add(offset + (4 << log_step)));
|
||||
val5.store(values.add(offset + (5 << log_step)));
|
||||
val6.store(values.add(offset + (6 << log_step)));
|
||||
val7.store(values.add(offset + (7 << log_step)));
|
||||
}
|
||||
|
||||
/// Applies 2 ibutterfly layers on 4 vectors of 16 M31 elements.
|
||||
///
|
||||
/// Vectorized over the 16 elements of the vectors.
|
||||
/// Used for radix-4 ifft.
|
||||
/// Each ibutterfly layer has 2 SIMD butterflies.
|
||||
/// Total of 4 SIMD butterflies.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `values`: Pointer to the entire value array.
|
||||
/// - `offset`: The offset of the first value in the array.
|
||||
/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values
|
||||
/// that need to be transformed. For layer `i` this is `i - 4`.
|
||||
/// - `twiddles_dbl0/1`: The double of the twiddles for the 2 layers of ibutterflies. Each layer has
|
||||
/// 2/1 twiddles.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
|
||||
pub unsafe fn ifft2(
|
||||
values: *mut u32,
|
||||
offset: usize,
|
||||
log_step: usize,
|
||||
twiddles_dbl0: [u32; 2],
|
||||
twiddles_dbl1: [u32; 1],
|
||||
) {
|
||||
// Load the 4 SIMD vectors from the array.
|
||||
let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const());
|
||||
let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const());
|
||||
let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const());
|
||||
let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const());
|
||||
|
||||
// Apply the first layer of butterflies.
|
||||
(val0, val1) = simd_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0]));
|
||||
(val2, val3) = simd_ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1]));
|
||||
|
||||
// Apply the second layer of butterflies.
|
||||
(val0, val2) = simd_ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0]));
|
||||
(val1, val3) = simd_ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0]));
|
||||
|
||||
// Store the 4 SIMD vectors back to the array.
|
||||
val0.store(values.add(offset + (0 << log_step)));
|
||||
val1.store(values.add(offset + (1 << log_step)));
|
||||
val2.store(values.add(offset + (2 << log_step)));
|
||||
val3.store(values.add(offset + (3 << log_step)));
|
||||
}
|
||||
|
||||
/// Applies 1 ibutterfly layer on 2 vectors of 16 M31 elements.
|
||||
///
|
||||
/// Vectorized over the 16 elements of the vectors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `values`: Pointer to the entire value array.
|
||||
/// - `offset`: The offset of the first value in the array.
|
||||
/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values
|
||||
/// that need to be transformed. For layer `i` this is `i - 4`.
|
||||
/// - `twiddles_dbl0`: The double of the twiddles for the ibutterfly layer.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
|
||||
pub unsafe fn ifft1(values: *mut u32, offset: usize, log_step: usize, twiddles_dbl0: [u32; 1]) {
|
||||
// Load the 2 SIMD vectors from the array.
|
||||
let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const());
|
||||
let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const());
|
||||
|
||||
(val0, val1) = simd_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0]));
|
||||
|
||||
// Store the 2 SIMD vectors back to the array.
|
||||
val0.store(values.add(offset + (0 << log_step)));
|
||||
val1.store(values.add(offset + (1 << log_step)));
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::mem::transmute;
|
||||
use std::simd::u32x16;
|
||||
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use super::{
|
||||
get_itwiddle_dbls, ifft, ifft3, ifft_lower_with_vecwise, simd_ibutterfly,
|
||||
vecwise_ibutterflies,
|
||||
};
|
||||
use crate::core::backend::cpu::CpuCircleEvaluation;
|
||||
use crate::core::backend::simd::column::BaseColumn;
|
||||
use crate::core::backend::simd::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE};
|
||||
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
|
||||
use crate::core::backend::Column;
|
||||
use crate::core::fft::ibutterfly as ground_truth_ibutterfly;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::poly::circle::{CanonicCoset, CircleDomain};
|
||||
|
||||
#[test]
|
||||
fn test_ibutterfly() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let mut v0: [BaseField; N_LANES] = rng.gen();
|
||||
let mut v1: [BaseField; N_LANES] = rng.gen();
|
||||
let twiddle: [BaseField; N_LANES] = rng.gen();
|
||||
let twiddle_dbl = twiddle.map(|v| v.0 * 2);
|
||||
|
||||
let (r0, r1) = simd_ibutterfly(v0.into(), v1.into(), twiddle_dbl.into());
|
||||
|
||||
let r0 = r0.to_array();
|
||||
let r1 = r1.to_array();
|
||||
for i in 0..N_LANES {
|
||||
ground_truth_ibutterfly(&mut v0[i], &mut v1[i], twiddle[i]);
|
||||
assert_eq!((v0[i], v1[i]), (r0[i], r1[i]), "mismatch at i={i}");
|
||||
}
|
||||
}
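// Sketch added for illustration (not part of the diff): the per-lane operation that
// `simd_ibutterfly` applies, written on scalar values; note the SIMD path receives the
// *double* of the twiddle, not the twiddle itself.
#[allow(dead_code)]
fn ibutterfly_reference(v0: BaseField, v1: BaseField, t: BaseField) -> (BaseField, BaseField) {
    (v0 + v1, (v0 - v1) * t)
}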
|
||||
|
||||
#[test]
|
||||
fn test_ifft3() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let values = rng.gen::<[BaseField; 8]>().map(PackedBaseField::broadcast);
|
||||
let twiddles0: [BaseField; 4] = rng.gen();
|
||||
let twiddles1: [BaseField; 2] = rng.gen();
|
||||
let twiddles2: [BaseField; 1] = rng.gen();
|
||||
let twiddles0_dbl = twiddles0.map(|v| v.0 * 2);
|
||||
let twiddles1_dbl = twiddles1.map(|v| v.0 * 2);
|
||||
let twiddles2_dbl = twiddles2.map(|v| v.0 * 2);
|
||||
|
||||
let mut res = values;
|
||||
unsafe {
|
||||
ifft3(
|
||||
transmute(res.as_mut_ptr()),
|
||||
0,
|
||||
LOG_N_LANES as usize,
|
||||
twiddles0_dbl,
|
||||
twiddles1_dbl,
|
||||
twiddles2_dbl,
|
||||
)
|
||||
};
|
||||
|
||||
let mut expected = values.map(|v| v.to_array()[0]);
|
||||
for i in 0..8 {
|
||||
let j = i ^ 1;
|
||||
if i > j {
|
||||
continue;
|
||||
}
|
||||
let (mut v0, mut v1) = (expected[i], expected[j]);
|
||||
ground_truth_ibutterfly(&mut v0, &mut v1, twiddles0[i / 2]);
|
||||
(expected[i], expected[j]) = (v0, v1);
|
||||
}
|
||||
for i in 0..8 {
|
||||
let j = i ^ 2;
|
||||
if i > j {
|
||||
continue;
|
||||
}
|
||||
let (mut v0, mut v1) = (expected[i], expected[j]);
|
||||
ground_truth_ibutterfly(&mut v0, &mut v1, twiddles1[i / 4]);
|
||||
(expected[i], expected[j]) = (v0, v1);
|
||||
}
|
||||
for i in 0..8 {
|
||||
let j = i ^ 4;
|
||||
if i > j {
|
||||
continue;
|
||||
}
|
||||
let (mut v0, mut v1) = (expected[i], expected[j]);
|
||||
ground_truth_ibutterfly(&mut v0, &mut v1, twiddles2[0]);
|
||||
(expected[i], expected[j]) = (v0, v1);
|
||||
}
|
||||
for i in 0..8 {
|
||||
assert_eq!(
|
||||
res[i].to_array(),
|
||||
[expected[i]; N_LANES],
|
||||
"mismatch at i={i}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vecwise_ibutterflies() {
|
||||
let domain = CanonicCoset::new(5).circle_domain();
|
||||
let twiddle_dbls = get_itwiddle_dbls(domain.half_coset);
|
||||
assert_eq!(twiddle_dbls.len(), 4);
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let values: [[BaseField; 16]; 2] = rng.gen();
|
||||
|
||||
let res = {
|
||||
let (val0, val1) = vecwise_ibutterflies(
|
||||
values[0].into(),
|
||||
values[1].into(),
|
||||
twiddle_dbls[0].clone().try_into().unwrap(),
|
||||
twiddle_dbls[1].clone().try_into().unwrap(),
|
||||
twiddle_dbls[2].clone().try_into().unwrap(),
|
||||
);
|
||||
let (val0, val1) = simd_ibutterfly(val0, val1, u32x16::splat(twiddle_dbls[3][0]));
|
||||
[val0.to_array(), val1.to_array()].concat()
|
||||
};
|
||||
|
||||
assert_eq!(res, ground_truth_ifft(domain, values.flatten()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ifft_lower_with_vecwise() {
|
||||
for log_size in 5..12 {
|
||||
let domain = CanonicCoset::new(log_size).circle_domain();
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let values = (0..domain.size()).map(|_| rng.gen()).collect_vec();
|
||||
let twiddle_dbls = get_itwiddle_dbls(domain.half_coset);
|
||||
|
||||
let mut res = values.iter().copied().collect::<BaseColumn>();
|
||||
unsafe {
|
||||
ifft_lower_with_vecwise(
|
||||
transmute(res.data.as_mut_ptr()),
|
||||
&twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(),
|
||||
log_size as usize,
|
||||
log_size as usize,
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res.to_cpu(), ground_truth_ifft(domain, &values));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ifft_full() {
|
||||
for log_size in CACHED_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 3 {
|
||||
let domain = CanonicCoset::new(log_size).circle_domain();
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let values = (0..domain.size()).map(|_| rng.gen()).collect_vec();
|
||||
let twiddle_dbls = get_itwiddle_dbls(domain.half_coset);
|
||||
|
||||
let mut res = values.iter().copied().collect::<BaseColumn>();
|
||||
unsafe {
|
||||
ifft(
|
||||
transmute(res.data.as_mut_ptr()),
|
||||
&twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(),
|
||||
log_size as usize,
|
||||
);
|
||||
transpose_vecs(transmute(res.data.as_mut_ptr()), log_size as usize - 4);
|
||||
}
|
||||
|
||||
assert_eq!(res.to_cpu(), ground_truth_ifft(domain, &values));
|
||||
}
|
||||
}
|
||||
|
||||
fn ground_truth_ifft(domain: CircleDomain, values: &[BaseField]) -> Vec<BaseField> {
|
||||
let eval = CpuCircleEvaluation::new(domain, values.to_vec());
|
||||
let mut res = eval.interpolate().coeffs;
|
||||
let denorm = BaseField::from(domain.size());
|
||||
res.iter_mut().for_each(|v| *v *= denorm);
|
||||
res
|
||||
}
|
||||
}
|
||||
120
Stwo_wrapper/crates/prover/src/core/backend/simd/fft/mod.rs
Normal file
@ -0,0 +1,120 @@
|
||||
use std::simd::{simd_swizzle, u32x16, u32x8};
|
||||
|
||||
use super::m31::PackedBaseField;
|
||||
use crate::core::fields::m31::P;
|
||||
|
||||
pub mod ifft;
|
||||
pub mod rfft;
|
||||
|
||||
pub const CACHED_FFT_LOG_SIZE: u32 = 16;
|
||||
|
||||
pub const MIN_FFT_LOG_SIZE: u32 = 5;
|
||||
|
||||
// TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce
|
||||
// it somewhere.
|
||||
|
||||
/// Transposes the SIMD vectors in the given array.
|
||||
///
|
||||
/// Swaps the bit index abc <-> cba, where |a|=|c| and |b| = 0 or 1, according to the parity of
|
||||
/// `log_n_vecs`.
|
||||
/// When `log_n_vecs` is odd, the single middle bit `b` stays in place while `a` and `c` swap.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `values`: A mutable pointer to the values that are to be transposed.
|
||||
/// - `log_n_vecs`: The log of the number of SIMD vectors in the `values` array.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `values` does not have the same alignment as [`u32x16`].
|
||||
pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) {
|
||||
let half = log_n_vecs / 2;
|
||||
for b in 0..1 << (log_n_vecs & 1) {
|
||||
for a in 0..1 << half {
|
||||
for c in 0..1 << half {
|
||||
let i = (a << (log_n_vecs - half)) | (b << half) | c;
|
||||
let j = (c << (log_n_vecs - half)) | (b << half) | a;
|
||||
if i >= j {
|
||||
continue;
|
||||
}
|
||||
let val0 = load(values.add(i << 4).cast_const());
|
||||
let val1 = load(values.add(j << 4).cast_const());
|
||||
store(values.add(i << 4), val1);
|
||||
store(values.add(j << 4), val0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
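// Illustrative sketch (not part of the upstream source): a scalar model of the index
// permutation performed above, useful for reasoning about the abc <-> cba bit swap.
// The mapping is an involution, so applying it twice restores the original index.
#[cfg(test)]
mod transpose_index_sketch {
    /// Scalar model of the (i, j) pairing used by `transpose_vecs`.
    fn transposed_index(i: usize, log_n_vecs: usize) -> usize {
        let half = log_n_vecs / 2;
        let a = i >> (log_n_vecs - half);
        let b = (i >> half) & ((1 << (log_n_vecs & 1)) - 1);
        let c = i & ((1 << half) - 1);
        (c << (log_n_vecs - half)) | (b << half) | a
    }

    #[test]
    fn transposed_index_is_an_involution() {
        for log_n_vecs in 2usize..=7 {
            for i in 0..1usize << log_n_vecs {
                let j = transposed_index(i, log_n_vecs);
                assert_eq!(transposed_index(j, log_n_vecs), i);
            }
        }
    }
}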
|
||||
|
||||
/// Computes the twiddles for the first fft layer from the second, and loads both to SIMD registers.
|
||||
///
|
||||
/// Returns the twiddles for the first layer and the twiddles for the second layer.
|
||||
pub fn compute_first_twiddles(twiddle1_dbl: u32x8) -> (u32x16, u32x16) {
|
||||
// Start by loading the twiddles for the second layer (layer 1):
|
||||
let t1 = simd_swizzle!(
|
||||
twiddle1_dbl,
|
||||
twiddle1_dbl,
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
|
||||
);
|
||||
|
||||
// The twiddles for layer 0 can be computed from the twiddles for layer 1.
|
||||
// Since the twiddles are bit reversed, we consider the circle domain in bit reversed order.
|
||||
// Each consecutive 4 points in the bit reversed order of a coset form a circle coset of size 4.
|
||||
// A circle coset of size 4 in bit reversed order looks like this:
|
||||
// [(x, y), (-x, -y), (y, -x), (-y, x)]
|
||||
// Note: This is related to the choice of M31_CIRCLE_GEN, and the fact that a quarter rotation
|
||||
// is (0,-1) and not (0,1). (0,1) would yield another relation.
|
||||
// The twiddles for layer 0 are the y coordinates:
|
||||
// [y, -y, -x, x]
|
||||
// The twiddles for layer 1 in bit reversed order are the x coordinates:
|
||||
// [x, y]
|
||||
// This also works for the inverses of the twiddles.
|
||||
|
||||
// The twiddles for layer 0 are computed like this:
|
||||
// t0[4i:4i+3] = [t1[2i+1], -t1[2i+1], -t1[2i], t1[2i]]
|
||||
// Xoring a double twiddle with P*2 transforms it into the double of its negation.
|
||||
// Note that this keeps the values as a double of a value in the range [0, P].
|
||||
const P2: u32 = P * 2;
|
||||
const NEGATION_MASK: u32x16 =
|
||||
u32x16::from_array([0, P2, P2, 0, 0, P2, P2, 0, 0, P2, P2, 0, 0, P2, P2, 0]);
|
||||
let t0 = simd_swizzle!(
|
||||
t1,
|
||||
[
|
||||
0b0001, 0b0001, 0b0000, 0b0000, 0b0011, 0b0011, 0b0010, 0b0010, 0b0101, 0b0101, 0b0100,
|
||||
0b0100, 0b0111, 0b0111, 0b0110, 0b0110,
|
||||
]
|
||||
) ^ NEGATION_MASK;
|
||||
(t0, t1)
|
||||
}
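// Sketch (illustrative only): the "xor with 2*P" trick used above, checked on scalars.
// For an even `d = 2 * v` with `v` in `[0, P]`, `d ^ (2 * P)` equals `2 * P - d`,
// i.e. the double of `P - v`, which is the negation of `v` mod P.
#[cfg(test)]
mod negation_mask_sketch {
    use crate::core::fields::m31::P;

    #[test]
    fn xor_with_2p_doubles_the_negation() {
        for v in [0u32, 1, 2, 12345, P / 2, P - 1, P] {
            let d = v * 2;
            assert_eq!(d ^ (2 * P), 2 * P - d);
        }
    }
}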
|
||||
|
||||
#[inline]
|
||||
unsafe fn load(mem_addr: *const u32) -> u32x16 {
|
||||
std::ptr::read(mem_addr as *const u32x16)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn store(mem_addr: *mut u32, a: u32x16) {
|
||||
std::ptr::write(mem_addr as *mut u32x16, a);
|
||||
}
|
||||
|
||||
/// Computes `v * twiddle`
|
||||
fn mul_twiddle(v: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField {
|
||||
// TODO: Come up with a better approach than `cfg`ing on target_feature.
|
||||
// TODO: Ensure all these branches get tested in the CI.
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] {
|
||||
// TODO: For architectures that when multiplying require doubling then the twiddles
|
||||
// should be precomputed as double. For other architectures, the twiddle should be
|
||||
// precomputed without doubling.
|
||||
crate::core::backend::simd::m31::_mul_doubled_neon(v, twiddle_dbl)
|
||||
} else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] {
|
||||
crate::core::backend::simd::m31::_mul_doubled_wasm(v, twiddle_dbl)
|
||||
} else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] {
|
||||
crate::core::backend::simd::m31::_mul_doubled_avx512(v, twiddle_dbl)
|
||||
} else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] {
|
||||
crate::core::backend::simd::m31::_mul_doubled_avx2(v, twiddle_dbl)
|
||||
} else {
|
||||
crate::core::backend::simd::m31::_mul_doubled_simd(v, twiddle_dbl)
|
||||
}
|
||||
}
|
||||
}
|
||||
742
Stwo_wrapper/crates/prover/src/core/backend/simd/fft/rfft.rs
Normal file
@ -0,0 +1,742 @@
|
||||
//! Regular (forward) fft.
|
||||
|
||||
use std::array;
|
||||
use std::simd::{simd_swizzle, u32x16, u32x2, u32x4, u32x8};
|
||||
|
||||
use itertools::Itertools;
|
||||
|
||||
use super::{
|
||||
compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
|
||||
};
|
||||
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
|
||||
use crate::core::circle::Coset;
|
||||
use crate::core::utils::bit_reverse;
|
||||
|
||||
/// Performs a Circle Fast Fourier Transform (CFFT) on the given values.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `src`: A pointer to the values to transform.
|
||||
/// * `dst`: A pointer to the destination array.
|
||||
/// * `twiddle_dbl`: A reference to the doubles of the twiddle factors.
|
||||
/// * `log_n_elements`: The log of the number of elements in the `values` array.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function will panic if `log_n_elements` is less than `MIN_FFT_LOG_SIZE`.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
|
||||
pub unsafe fn fft(src: *const u32, dst: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) {
|
||||
assert!(log_n_elements >= MIN_FFT_LOG_SIZE as usize);
|
||||
let log_n_vecs = log_n_elements - LOG_N_LANES as usize;
|
||||
if log_n_elements <= CACHED_FFT_LOG_SIZE as usize {
|
||||
fft_lower_with_vecwise(src, dst, twiddle_dbl, log_n_elements, log_n_elements);
|
||||
return;
|
||||
}
|
||||
|
||||
let fft_layers_pre_transpose = log_n_vecs.div_ceil(2);
|
||||
let fft_layers_post_transpose = log_n_vecs / 2;
|
||||
fft_lower_without_vecwise(
|
||||
src,
|
||||
dst,
|
||||
&twiddle_dbl[(3 + fft_layers_pre_transpose)..],
|
||||
log_n_elements,
|
||||
fft_layers_post_transpose,
|
||||
);
|
||||
transpose_vecs(dst, log_n_vecs);
|
||||
fft_lower_with_vecwise(
|
||||
dst,
|
||||
dst,
|
||||
&twiddle_dbl[..3 + fft_layers_pre_transpose],
|
||||
log_n_elements,
|
||||
fft_layers_pre_transpose + LOG_N_LANES as usize,
|
||||
);
|
||||
}
|
||||
|
||||
/// Computes partial fft on `2^log_size` M31 elements.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
|
||||
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
|
||||
/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the fft. Layer `i`
|
||||
/// holds `2^(log_size - 1 - i)` twiddles.
|
||||
/// - `log_size`: The log of the number of M31 elements in the array.
|
||||
/// - `fft_layers`: The number of fft layers to apply, out of log_size.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `log_size` is not at least 5.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// `src` and `dst` must have the same alignment as [`PackedBaseField`].
|
||||
/// `fft_layers` must be at least 5.
|
||||
pub unsafe fn fft_lower_with_vecwise(
|
||||
src: *const u32,
|
||||
dst: *mut u32,
|
||||
twiddle_dbl: &[&[u32]],
|
||||
log_size: usize,
|
||||
fft_layers: usize,
|
||||
) {
|
||||
const VECWISE_FFT_BITS: usize = LOG_N_LANES as usize + 1;
|
||||
assert!(log_size >= VECWISE_FFT_BITS);
|
||||
|
||||
assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2));
|
||||
|
||||
for index_h in 0..1 << (log_size - fft_layers) {
|
||||
let mut src = src;
|
||||
for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3).rev() {
|
||||
match fft_layers - layer {
|
||||
1 => {
|
||||
fft1_loop(src, dst, &twiddle_dbl[(layer - 1)..], layer, index_h);
|
||||
}
|
||||
2 => {
|
||||
fft2_loop(src, dst, &twiddle_dbl[(layer - 1)..], layer, index_h);
|
||||
}
|
||||
_ => {
|
||||
fft3_loop(
|
||||
src,
|
||||
dst,
|
||||
&twiddle_dbl[(layer - 1)..],
|
||||
fft_layers - layer - 3,
|
||||
layer,
|
||||
index_h,
|
||||
);
|
||||
}
|
||||
}
|
||||
src = dst;
|
||||
}
|
||||
fft_vecwise_loop(
|
||||
src,
|
||||
dst,
|
||||
twiddle_dbl,
|
||||
fft_layers - VECWISE_FFT_BITS,
|
||||
index_h,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes partial fft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of
|
||||
/// the index).
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
|
||||
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
|
||||
/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the fft.
|
||||
/// - `log_size`: The log of the number of M31 elements in the array.
|
||||
/// - `fft_layers`: The number of fft layers to apply, out of `log_size - LOG_N_LANES`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `log_size` is not at least 4.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// `src` and `dst` must have the same alignment as [`PackedBaseField`].
|
||||
/// `fft_layers` must be at least 4.
|
||||
pub unsafe fn fft_lower_without_vecwise(
|
||||
src: *const u32,
|
||||
dst: *mut u32,
|
||||
twiddle_dbl: &[&[u32]],
|
||||
log_size: usize,
|
||||
fft_layers: usize,
|
||||
) {
|
||||
assert!(log_size >= LOG_N_LANES as usize);
|
||||
|
||||
for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) {
|
||||
let mut src = src;
|
||||
for layer in (0..fft_layers).step_by(3).rev() {
|
||||
let fixed_layer = layer + LOG_N_LANES as usize;
|
||||
match fft_layers - layer {
|
||||
1 => {
|
||||
fft1_loop(src, dst, &twiddle_dbl[layer..], fixed_layer, index_h);
|
||||
}
|
||||
2 => {
|
||||
fft2_loop(src, dst, &twiddle_dbl[layer..], fixed_layer, index_h);
|
||||
}
|
||||
_ => {
|
||||
fft3_loop(
|
||||
src,
|
||||
dst,
|
||||
&twiddle_dbl[layer..],
|
||||
fft_layers - layer - 3,
|
||||
fixed_layer,
|
||||
index_h,
|
||||
);
|
||||
}
|
||||
}
|
||||
src = dst;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs the last 5 fft layers across the entire array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
|
||||
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
|
||||
/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 5 fft layers.
|
||||
/// - `loop_bits`: The number of bits this loop needs to run on.
|
||||
/// - `index_h`: The higher part of the index, iterated by the caller.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
|
||||
unsafe fn fft_vecwise_loop(
|
||||
src: *const u32,
|
||||
dst: *mut u32,
|
||||
twiddle_dbl: &[&[u32]],
|
||||
loop_bits: usize,
|
||||
index_h: usize,
|
||||
) {
|
||||
for index_l in 0..1 << loop_bits {
|
||||
let index = (index_h << loop_bits) + index_l;
|
||||
let mut val0 = PackedBaseField::load(src.add(index * 32));
|
||||
let mut val1 = PackedBaseField::load(src.add(index * 32 + 16));
|
||||
(val0, val1) = simd_butterfly(
|
||||
val0,
|
||||
val1,
|
||||
u32x16::splat(*twiddle_dbl[3].get_unchecked(index)),
|
||||
);
|
||||
(val0, val1) = vecwise_butterflies(
|
||||
val0,
|
||||
val1,
|
||||
array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)),
|
||||
array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)),
|
||||
array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)),
|
||||
);
|
||||
val0.store(dst.add(index * 32));
|
||||
val1.store(dst.add(index * 32 + 16));
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs 3 fft layers across the entire array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
|
||||
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
|
||||
/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 3 fft layers.
|
||||
/// - `loop_bits`: The number of bits this loop needs to run on.
|
||||
/// - `layer`: The layer number of the first fft layer to apply. The layers `layer`, `layer + 1`,
|
||||
/// `layer + 2` are applied.
|
||||
/// - `index_h`: The higher part of the index, iterated by the caller.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
|
||||
unsafe fn fft3_loop(
|
||||
src: *const u32,
|
||||
dst: *mut u32,
|
||||
twiddle_dbl: &[&[u32]],
|
||||
loop_bits: usize,
|
||||
layer: usize,
|
||||
index_h: usize,
|
||||
) {
|
||||
for index_l in 0..1 << loop_bits {
|
||||
let index = (index_h << loop_bits) + index_l;
|
||||
let offset = index << (layer + 3);
|
||||
for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) {
|
||||
fft3(
|
||||
src,
|
||||
dst,
|
||||
offset + l,
|
||||
layer,
|
||||
array::from_fn(|i| {
|
||||
*twiddle_dbl[0].get_unchecked((index * 4 + i) & (twiddle_dbl[0].len() - 1))
|
||||
}),
|
||||
array::from_fn(|i| {
|
||||
*twiddle_dbl[1].get_unchecked((index * 2 + i) & (twiddle_dbl[1].len() - 1))
|
||||
}),
|
||||
array::from_fn(|i| {
|
||||
*twiddle_dbl[2].get_unchecked((index + i) & (twiddle_dbl[2].len() - 1))
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs 2 fft layers across the entire array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
|
||||
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
|
||||
/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 2 fft layers.
|
||||
/// - `loop_bits`: The number of bits this loop needs to run on.
|
||||
/// - `layer`: The layer number of the first fft layer to apply. The layers `layer`, `layer + 1` are
|
||||
/// applied.
|
||||
/// - `index`: The index, iterated by the caller.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
|
||||
unsafe fn fft2_loop(
|
||||
src: *const u32,
|
||||
dst: *mut u32,
|
||||
twiddle_dbl: &[&[u32]],
|
||||
layer: usize,
|
||||
index: usize,
|
||||
) {
|
||||
let offset = index << (layer + 2);
|
||||
for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) {
|
||||
fft2(
|
||||
src,
|
||||
dst,
|
||||
offset + l,
|
||||
layer,
|
||||
array::from_fn(|i| {
|
||||
*twiddle_dbl[0].get_unchecked((index * 2 + i) & (twiddle_dbl[0].len() - 1))
|
||||
}),
|
||||
array::from_fn(|i| {
|
||||
*twiddle_dbl[1].get_unchecked((index + i) & (twiddle_dbl[1].len() - 1))
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs 1 fft layer across the entire array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
|
||||
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
|
||||
/// - `twiddle_dbl`: The doubles of the twiddle factors for the fft layer.
|
||||
/// - `layer`: The layer number of the fft layer to apply.
|
||||
/// - `index_h`: The higher part of the index, iterated by the caller.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
|
||||
unsafe fn fft1_loop(
|
||||
src: *const u32,
|
||||
dst: *mut u32,
|
||||
twiddle_dbl: &[&[u32]],
|
||||
layer: usize,
|
||||
index: usize,
|
||||
) {
|
||||
let offset = index << (layer + 1);
|
||||
for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) {
|
||||
fft1(
|
||||
src,
|
||||
dst,
|
||||
offset + l,
|
||||
layer,
|
||||
array::from_fn(|i| {
|
||||
*twiddle_dbl[0].get_unchecked((index + i) & (twiddle_dbl[0].len() - 1))
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the butterfly operation for packed M31 elements.
|
||||
///
|
||||
/// Returns `val0 + t val1, val0 - t val1`. `val0, val1` are packed M31 elements, 16 M31 elements
|
||||
/// each. Each value is assumed to be in unreduced form, [0, P] including P. Returned values are in
|
||||
/// unreduced form, [0, P] including P. twiddle_dbl holds 16 values, each is a *double* of a twiddle
|
||||
/// factor, in unreduced form, [0, 2*P].
|
||||
pub fn simd_butterfly(
|
||||
val0: PackedBaseField,
|
||||
val1: PackedBaseField,
|
||||
twiddle_dbl: u32x16,
|
||||
) -> (PackedBaseField, PackedBaseField) {
|
||||
let prod = mul_twiddle(val1, twiddle_dbl);
|
||||
(val0 + prod, val0 - prod)
|
||||
}
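// Sketch (illustrative only): the per-lane scalar semantics of the butterfly above,
// written in plain u64 arithmetic mod P. The SIMD version operates on doubled twiddles
// and unreduced representatives, but the value each lane ends up holding is this.
#[cfg(test)]
mod butterfly_sketch {
    #[test]
    fn scalar_butterfly_semantics() {
        const P: u64 = (1 << 31) - 1;
        let (v0, v1, t) = (5u64, 7u64, 11u64);
        let b0 = (v0 + t * v1) % P; // v0 + t * v1 mod P
        let b1 = (v0 + P - t * v1 % P) % P; // v0 - t * v1 mod P
        assert_eq!((b0, b1), (82, 2147483575));
    }
}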
|
||||
|
||||
/// Runs fft on 2 vectors of 16 M31 elements.
|
||||
///
|
||||
/// This amounts to 4 butterfly layers, each with 16 butterflies.
|
||||
/// Each of the vectors represents naturally ordered polynomial coefficients.
|
||||
/// Each value in the vectors is in unreduced form: [0, P] including P.
|
||||
/// Takes 3 twiddle arrays, one per passed layer (the lowest layer's twiddles are derived from the
/// layer above via [`compute_first_twiddles`]), each holding the doubles of the corresponding twiddles.
|
||||
/// The first layer (higher bit of the index) takes 2 twiddles.
|
||||
/// The second layer takes 4 twiddles.
|
||||
/// etc.
|
||||
pub fn vecwise_butterflies(
|
||||
mut val0: PackedBaseField,
|
||||
mut val1: PackedBaseField,
|
||||
twiddle1_dbl: [u32; 8],
|
||||
twiddle2_dbl: [u32; 4],
|
||||
twiddle3_dbl: [u32; 2],
|
||||
) -> (PackedBaseField, PackedBaseField) {
|
||||
// TODO(spapini): Compute twiddle0 from twiddle1.
|
||||
// TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly.
|
||||
// The implementation is the exact reverse of vecwise_ibutterflies().
|
||||
// See the comments in its body for more info.
|
||||
let t = simd_swizzle!(
|
||||
u32x2::from(twiddle3_dbl),
|
||||
[0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
|
||||
);
|
||||
(val0, val1) = val0.interleave(val1);
|
||||
(val0, val1) = simd_butterfly(val0, val1, t);
|
||||
|
||||
let t = simd_swizzle!(
|
||||
u32x4::from(twiddle2_dbl),
|
||||
[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
|
||||
);
|
||||
(val0, val1) = val0.interleave(val1);
|
||||
(val0, val1) = simd_butterfly(val0, val1, t);
|
||||
|
||||
let (t0, t1) = compute_first_twiddles(u32x8::from(twiddle1_dbl));
|
||||
(val0, val1) = val0.interleave(val1);
|
||||
(val0, val1) = simd_butterfly(val0, val1, t1);
|
||||
|
||||
(val0, val1) = val0.interleave(val1);
|
||||
(val0, val1) = simd_butterfly(val0, val1, t0);
|
||||
|
||||
val0.interleave(val1)
|
||||
}
|
||||
|
||||
/// Returns the line twiddles (x points) for an fft on a coset.
|
||||
pub fn get_twiddle_dbls(mut coset: Coset) -> Vec<Vec<u32>> {
|
||||
let mut res = vec![];
|
||||
for _ in 0..coset.log_size() {
|
||||
res.push(
|
||||
coset
|
||||
.iter()
|
||||
.take(coset.size() / 2)
|
||||
.map(|p| p.x.0 * 2)
|
||||
.collect_vec(),
|
||||
);
|
||||
bit_reverse(res.last_mut().unwrap());
|
||||
coset = coset.double();
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
/// Applies 3 butterfly layers on 8 vectors of 16 M31 elements.
|
||||
///
|
||||
/// Vectorized over the 16 elements of the vectors.
|
||||
/// Used for radix-8 fft.
|
||||
/// Each butterfly layer has 4 SIMD butterflies.
|
||||
/// Total of 12 SIMD butterflies.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
|
||||
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
|
||||
/// - `offset`: The offset of the first value in the array.
|
||||
/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values
|
||||
/// that need to be transformed. For layer i this is i - 4.
|
||||
/// - `twiddles_dbl0/1/2`: The double of the twiddles for the 3 layers of butterflies. Each layer
|
||||
/// has 4/2/1 twiddles.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
|
||||
pub unsafe fn fft3(
|
||||
src: *const u32,
|
||||
dst: *mut u32,
|
||||
offset: usize,
|
||||
log_step: usize,
|
||||
twiddles_dbl0: [u32; 4],
|
||||
twiddles_dbl1: [u32; 2],
|
||||
twiddles_dbl2: [u32; 1],
|
||||
) {
|
||||
// Load the 8 SIMD vectors from the array.
|
||||
let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)));
|
||||
let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)));
|
||||
let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step)));
|
||||
let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step)));
|
||||
let mut val4 = PackedBaseField::load(src.add(offset + (4 << log_step)));
|
||||
let mut val5 = PackedBaseField::load(src.add(offset + (5 << log_step)));
|
||||
let mut val6 = PackedBaseField::load(src.add(offset + (6 << log_step)));
|
||||
let mut val7 = PackedBaseField::load(src.add(offset + (7 << log_step)));
|
||||
|
||||
// Apply the third layer of butterflies.
|
||||
(val0, val4) = simd_butterfly(val0, val4, u32x16::splat(twiddles_dbl2[0]));
|
||||
(val1, val5) = simd_butterfly(val1, val5, u32x16::splat(twiddles_dbl2[0]));
|
||||
(val2, val6) = simd_butterfly(val2, val6, u32x16::splat(twiddles_dbl2[0]));
|
||||
(val3, val7) = simd_butterfly(val3, val7, u32x16::splat(twiddles_dbl2[0]));
|
||||
|
||||
// Apply the second layer of butterflies.
|
||||
(val0, val2) = simd_butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0]));
|
||||
(val1, val3) = simd_butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0]));
|
||||
(val4, val6) = simd_butterfly(val4, val6, u32x16::splat(twiddles_dbl1[1]));
|
||||
(val5, val7) = simd_butterfly(val5, val7, u32x16::splat(twiddles_dbl1[1]));
|
||||
|
||||
// Apply the first layer of butterflies.
|
||||
(val0, val1) = simd_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0]));
|
||||
(val2, val3) = simd_butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1]));
|
||||
(val4, val5) = simd_butterfly(val4, val5, u32x16::splat(twiddles_dbl0[2]));
|
||||
(val6, val7) = simd_butterfly(val6, val7, u32x16::splat(twiddles_dbl0[3]));
|
||||
|
||||
// Store the 8 SIMD vectors back to the array.
|
||||
val0.store(dst.add(offset + (0 << log_step)));
|
||||
val1.store(dst.add(offset + (1 << log_step)));
|
||||
val2.store(dst.add(offset + (2 << log_step)));
|
||||
val3.store(dst.add(offset + (3 << log_step)));
|
||||
val4.store(dst.add(offset + (4 << log_step)));
|
||||
val5.store(dst.add(offset + (5 << log_step)));
|
||||
val6.store(dst.add(offset + (6 << log_step)));
|
||||
val7.store(dst.add(offset + (7 << log_step)));
|
||||
}
|
||||
|
||||
/// Applies 2 butterfly layers on 4 vectors of 16 M31 elements.
|
||||
///
|
||||
/// Vectorized over the 16 elements of the vectors.
|
||||
/// Used for radix-4 fft.
|
||||
/// Each butterfly layer has 2 SIMD butterflies.
|
||||
/// Total of 4 SIMD butterflies.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
|
||||
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
|
||||
/// - `offset`: The offset of the first value in the array.
|
||||
/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values
|
||||
/// that need to be transformed. For layer i this is i - 4.
|
||||
/// - `twiddles_dbl0/1`: The double of the twiddles for the 2 layers of butterflies. Each layer has
|
||||
/// 2/1 twiddles.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
|
||||
pub unsafe fn fft2(
|
||||
src: *const u32,
|
||||
dst: *mut u32,
|
||||
offset: usize,
|
||||
log_step: usize,
|
||||
twiddles_dbl0: [u32; 2],
|
||||
twiddles_dbl1: [u32; 1],
|
||||
) {
|
||||
// Load the 4 SIMD vectors from the array.
|
||||
let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)));
|
||||
let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)));
|
||||
let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step)));
|
||||
let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step)));
|
||||
|
||||
// Apply the second layer of butterflies.
|
||||
(val0, val2) = simd_butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0]));
|
||||
(val1, val3) = simd_butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0]));
|
||||
|
||||
// Apply the first layer of butterflies.
|
||||
(val0, val1) = simd_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0]));
|
||||
(val2, val3) = simd_butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1]));
|
||||
|
||||
// Store the 4 SIMD vectors back to the array.
|
||||
val0.store(dst.add(offset + (0 << log_step)));
|
||||
val1.store(dst.add(offset + (1 << log_step)));
|
||||
val2.store(dst.add(offset + (2 << log_step)));
|
||||
val3.store(dst.add(offset + (3 << log_step)));
|
||||
}
|
||||
|
||||
/// Applies 1 butterfly layer on 2 vectors of 16 M31 elements.
|
||||
///
|
||||
/// Vectorized over the 16 elements of the vectors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
|
||||
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
|
||||
/// - `offset`: The offset of the first value in the array.
|
||||
/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values
|
||||
/// that need to be transformed. For layer i this is i - 4.
|
||||
/// - `twiddles_dbl0`: The double of the twiddles for the butterfly layer.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
|
||||
pub unsafe fn fft1(
|
||||
src: *const u32,
|
||||
dst: *mut u32,
|
||||
offset: usize,
|
||||
log_step: usize,
|
||||
twiddles_dbl0: [u32; 1],
|
||||
) {
|
||||
// Load the 2 SIMD vectors from the array.
|
||||
let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)));
|
||||
let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)));
|
||||
|
||||
(val0, val1) = simd_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0]));
|
||||
|
||||
// Store the 2 SIMD vectors back to the array.
|
||||
val0.store(dst.add(offset + (0 << log_step)));
|
||||
val1.store(dst.add(offset + (1 << log_step)));
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::mem::transmute;
|
||||
use std::simd::u32x16;
|
||||
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use super::{
|
||||
fft, fft3, fft_lower_with_vecwise, get_twiddle_dbls, simd_butterfly, vecwise_butterflies,
|
||||
};
|
||||
use crate::core::backend::cpu::CpuCirclePoly;
|
||||
use crate::core::backend::simd::column::BaseColumn;
|
||||
use crate::core::backend::simd::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE};
|
||||
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
|
||||
use crate::core::backend::Column;
|
||||
use crate::core::fft::butterfly as ground_truth_butterfly;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::poly::circle::{CanonicCoset, CircleDomain};
|
||||
|
||||
#[test]
|
||||
fn test_butterfly() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let mut v0: [BaseField; N_LANES] = rng.gen();
|
||||
let mut v1: [BaseField; N_LANES] = rng.gen();
|
||||
let twiddle: [BaseField; N_LANES] = rng.gen();
|
||||
let twiddle_dbl = twiddle.map(|v| v.0 * 2);
|
||||
|
||||
let (r0, r1) = simd_butterfly(v0.into(), v1.into(), twiddle_dbl.into());
|
||||
|
||||
let r0 = r0.to_array();
|
||||
let r1 = r1.to_array();
|
||||
for i in 0..N_LANES {
|
||||
ground_truth_butterfly(&mut v0[i], &mut v1[i], twiddle[i]);
|
||||
assert_eq!((v0[i], v1[i]), (r0[i], r1[i]), "mismatch at i={i}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fft3() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let values = rng.gen::<[BaseField; 8]>().map(PackedBaseField::broadcast);
|
||||
let twiddles0: [BaseField; 4] = rng.gen();
|
||||
let twiddles1: [BaseField; 2] = rng.gen();
|
||||
let twiddles2: [BaseField; 1] = rng.gen();
|
||||
let twiddles0_dbl = twiddles0.map(|v| v.0 * 2);
|
||||
let twiddles1_dbl = twiddles1.map(|v| v.0 * 2);
|
||||
let twiddles2_dbl = twiddles2.map(|v| v.0 * 2);
|
||||
|
||||
let mut res = values;
|
||||
unsafe {
|
||||
fft3(
|
||||
transmute(res.as_ptr()),
|
||||
transmute(res.as_mut_ptr()),
|
||||
0,
|
||||
LOG_N_LANES as usize,
|
||||
twiddles0_dbl,
|
||||
twiddles1_dbl,
|
||||
twiddles2_dbl,
|
||||
)
|
||||
};
|
||||
|
||||
let mut expected = values.map(|v| v.to_array()[0]);
|
||||
for i in 0..8 {
|
||||
let j = i ^ 4;
|
||||
if i > j {
|
||||
continue;
|
||||
}
|
||||
let (mut v0, mut v1) = (expected[i], expected[j]);
|
||||
ground_truth_butterfly(&mut v0, &mut v1, twiddles2[0]);
|
||||
(expected[i], expected[j]) = (v0, v1);
|
||||
}
|
||||
for i in 0..8 {
|
||||
let j = i ^ 2;
|
||||
if i > j {
|
||||
continue;
|
||||
}
|
||||
let (mut v0, mut v1) = (expected[i], expected[j]);
|
||||
ground_truth_butterfly(&mut v0, &mut v1, twiddles1[i / 4]);
|
||||
(expected[i], expected[j]) = (v0, v1);
|
||||
}
|
||||
for i in 0..8 {
|
||||
let j = i ^ 1;
|
||||
if i > j {
|
||||
continue;
|
||||
}
|
||||
let (mut v0, mut v1) = (expected[i], expected[j]);
|
||||
ground_truth_butterfly(&mut v0, &mut v1, twiddles0[i / 2]);
|
||||
(expected[i], expected[j]) = (v0, v1);
|
||||
}
|
||||
for i in 0..8 {
|
||||
assert_eq!(
|
||||
res[i].to_array(),
|
||||
[expected[i]; N_LANES],
|
||||
"mismatch at i={i}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vecwise_butterflies() {
|
||||
let domain = CanonicCoset::new(5).circle_domain();
|
||||
let twiddle_dbls = get_twiddle_dbls(domain.half_coset);
|
||||
assert_eq!(twiddle_dbls.len(), 4);
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let values: [[BaseField; 16]; 2] = rng.gen();
|
||||
|
||||
let res = {
|
||||
let (val0, val1) = simd_butterfly(
|
||||
values[0].into(),
|
||||
values[1].into(),
|
||||
u32x16::splat(twiddle_dbls[3][0]),
|
||||
);
|
||||
let (val0, val1) = vecwise_butterflies(
|
||||
val0,
|
||||
val1,
|
||||
twiddle_dbls[0].clone().try_into().unwrap(),
|
||||
twiddle_dbls[1].clone().try_into().unwrap(),
|
||||
twiddle_dbls[2].clone().try_into().unwrap(),
|
||||
);
|
||||
[val0.to_array(), val1.to_array()].concat()
|
||||
};
|
||||
|
||||
assert_eq!(res, ground_truth_fft(domain, values.flatten()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fft_lower() {
|
||||
for log_size in 5..12 {
|
||||
let domain = CanonicCoset::new(log_size).circle_domain();
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let values = (0..domain.size()).map(|_| rng.gen()).collect_vec();
|
||||
let twiddle_dbls = get_twiddle_dbls(domain.half_coset);
|
||||
|
||||
let mut res = values.iter().copied().collect::<BaseColumn>();
|
||||
unsafe {
|
||||
fft_lower_with_vecwise(
|
||||
transmute(res.data.as_ptr()),
|
||||
transmute(res.data.as_mut_ptr()),
|
||||
&twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(),
|
||||
log_size as usize,
|
||||
log_size as usize,
|
||||
)
|
||||
}
|
||||
|
||||
assert_eq!(res.to_cpu(), ground_truth_fft(domain, &values));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fft_full() {
|
||||
for log_size in CACHED_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 3 {
|
||||
let domain = CanonicCoset::new(log_size).circle_domain();
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let values = (0..domain.size()).map(|_| rng.gen()).collect_vec();
|
||||
let twiddle_dbls = get_twiddle_dbls(domain.half_coset);
|
||||
|
||||
let mut res = values.iter().copied().collect::<BaseColumn>();
|
||||
unsafe {
|
||||
transpose_vecs(transmute(res.data.as_mut_ptr()), log_size as usize - 4);
|
||||
fft(
|
||||
transmute(res.data.as_ptr()),
|
||||
transmute(res.data.as_mut_ptr()),
|
||||
&twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(),
|
||||
log_size as usize,
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(res.to_cpu(), ground_truth_fft(domain, &values));
|
||||
}
|
||||
}
|
||||
|
||||
fn ground_truth_fft(domain: CircleDomain, values: &[BaseField]) -> Vec<BaseField> {
|
||||
let poly = CpuCirclePoly::new(values.to_vec());
|
||||
poly.evaluate(domain).values
|
||||
}
|
||||
}
|
||||
261
Stwo_wrapper/crates/prover/src/core/backend/simd/fri.rs
Normal file
@ -0,0 +1,261 @@
|
||||
use std::array;
|
||||
use std::simd::u32x8;
|
||||
|
||||
use num_traits::Zero;
|
||||
|
||||
use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
|
||||
use super::SimdBackend;
|
||||
use crate::core::backend::simd::fft::compute_first_twiddles;
|
||||
use crate::core::backend::simd::fft::ifft::simd_ibutterfly;
|
||||
use crate::core::backend::simd::qm31::PackedSecureField;
|
||||
use crate::core::backend::Column;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::secure_column::SecureColumnByCoords;
|
||||
use crate::core::fri::{self, FriOps};
|
||||
use crate::core::poly::circle::SecureEvaluation;
|
||||
use crate::core::poly::line::LineEvaluation;
|
||||
use crate::core::poly::twiddles::TwiddleTree;
|
||||
use crate::core::poly::utils::domain_line_twiddles_from_tree;
|
||||
use crate::core::poly::BitReversedOrder;
|
||||
|
||||
impl FriOps for SimdBackend {
|
||||
fn fold_line(
|
||||
eval: &LineEvaluation<Self>,
|
||||
alpha: SecureField,
|
||||
twiddles: &TwiddleTree<Self>,
|
||||
) -> LineEvaluation<Self> {
|
||||
let log_size = eval.len().ilog2();
|
||||
if log_size <= LOG_N_LANES {
|
||||
let eval = fri::fold_line(&eval.to_cpu(), alpha);
|
||||
return LineEvaluation::new(eval.domain(), eval.values.into_iter().collect());
|
||||
}
|
||||
|
||||
let domain = eval.domain();
|
||||
let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0];
|
||||
|
||||
let mut folded_values = SecureColumnByCoords::<Self>::zeros(1 << (log_size - 1));
|
||||
|
||||
for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) {
|
||||
let value = unsafe {
|
||||
let twiddle_dbl: [u32; 16] =
|
||||
array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i));
|
||||
let val0 = eval.values.packed_at(vec_index * 2).into_packed_m31s();
|
||||
let val1 = eval.values.packed_at(vec_index * 2 + 1).into_packed_m31s();
|
||||
let pairs: [_; 4] = array::from_fn(|i| {
|
||||
let (a, b) = val0[i].deinterleave(val1[i]);
|
||||
simd_ibutterfly(a, b, std::mem::transmute(twiddle_dbl))
|
||||
});
|
||||
let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0));
|
||||
let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1));
|
||||
val0 + PackedSecureField::broadcast(alpha) * val1
|
||||
};
|
||||
unsafe { folded_values.set_packed(vec_index, value) };
|
||||
}
|
||||
|
||||
LineEvaluation::new(domain.double(), folded_values)
|
||||
}
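    // In scalar terms (illustrative sketch): each pair of sibling evaluations (a, b)
    // is combined into (a + b) + alpha * (a - b) * itwiddle, which is exactly what the
    // inverse butterfly followed by `val0 + alpha * val1` computes lane by lane; the
    // result lives on the halved (doubled) line domain.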
|
||||
|
||||
fn fold_circle_into_line(
|
||||
dst: &mut LineEvaluation<Self>,
|
||||
src: &SecureEvaluation<Self, BitReversedOrder>,
|
||||
alpha: SecureField,
|
||||
twiddles: &TwiddleTree<Self>,
|
||||
) {
|
||||
let log_size = src.len().ilog2();
|
||||
assert!(log_size > LOG_N_LANES, "Evaluation too small");
|
||||
|
||||
let domain = src.domain;
|
||||
let alpha_sq = alpha * alpha;
|
||||
let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0];
|
||||
|
||||
for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) {
|
||||
let value = unsafe {
|
||||
// The 16 twiddles of the circle domain can be derived from the 8 twiddles of the
|
||||
// next line domain. See `compute_first_twiddles()`.
|
||||
let twiddle_dbl = u32x8::from_array(array::from_fn(|i| {
|
||||
*itwiddles.get_unchecked(vec_index * 8 + i)
|
||||
}));
|
||||
let (t0, _) = compute_first_twiddles(twiddle_dbl);
|
||||
let val0 = src.values.packed_at(vec_index * 2).into_packed_m31s();
|
||||
let val1 = src.values.packed_at(vec_index * 2 + 1).into_packed_m31s();
|
||||
let pairs: [_; 4] = array::from_fn(|i| {
|
||||
let (a, b) = val0[i].deinterleave(val1[i]);
|
||||
simd_ibutterfly(a, b, t0)
|
||||
});
|
||||
let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0));
|
||||
let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1));
|
||||
val0 + PackedSecureField::broadcast(alpha) * val1
|
||||
};
|
||||
unsafe {
|
||||
dst.values.set_packed(
|
||||
vec_index,
|
||||
dst.values.packed_at(vec_index) * PackedSecureField::broadcast(alpha_sq)
|
||||
+ value,
|
||||
)
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn decompose(
|
||||
eval: &SecureEvaluation<Self, BitReversedOrder>,
|
||||
) -> (SecureEvaluation<Self, BitReversedOrder>, SecureField) {
|
||||
let lambda = decomposition_coefficient(eval);
|
||||
let broadcasted_lambda = PackedSecureField::broadcast(lambda);
|
||||
let mut g_values = SecureColumnByCoords::<Self>::zeros(eval.len());
|
||||
|
||||
let range = eval.len().div_ceil(N_LANES);
|
||||
let half_range = range / 2;
|
||||
for i in 0..half_range {
|
||||
let val = unsafe { eval.packed_at(i) } - broadcasted_lambda;
|
||||
unsafe { g_values.set_packed(i, val) }
|
||||
}
|
||||
for i in half_range..range {
|
||||
let val = unsafe { eval.packed_at(i) } + broadcasted_lambda;
|
||||
unsafe { g_values.set_packed(i, val) }
|
||||
}
|
||||
|
||||
let g = SecureEvaluation::new(eval.domain, g_values);
|
||||
(g, lambda)
|
||||
}
|
||||
}
|
||||
|
||||
/// See [`decomposition_coefficient`].
|
||||
///
|
||||
/// [`decomposition_coefficient`]: crate::core::backend::cpu::CpuBackend::decomposition_coefficient
|
||||
fn decomposition_coefficient(
|
||||
eval: &SecureEvaluation<SimdBackend, BitReversedOrder>,
|
||||
) -> SecureField {
|
||||
let cols = &eval.values.columns;
|
||||
let [mut x_sum, mut y_sum, mut z_sum, mut w_sum] = [PackedBaseField::zero(); 4];
|
||||
|
||||
let range = cols[0].len() / N_LANES;
|
||||
let (half_a, half_b) = (range / 2, range);
|
||||
|
||||
for i in 0..half_a {
|
||||
x_sum += cols[0].data[i];
|
||||
y_sum += cols[1].data[i];
|
||||
z_sum += cols[2].data[i];
|
||||
w_sum += cols[3].data[i];
|
||||
}
|
||||
for i in half_a..half_b {
|
||||
x_sum -= cols[0].data[i];
|
||||
y_sum -= cols[1].data[i];
|
||||
z_sum -= cols[2].data[i];
|
||||
w_sum -= cols[3].data[i];
|
||||
}
|
||||
|
||||
let x = x_sum.pointwise_sum();
|
||||
let y = y_sum.pointwise_sum();
|
||||
let z = z_sum.pointwise_sum();
|
||||
let w = w_sum.pointwise_sum();
|
||||
|
||||
SecureField::from_m31(x, y, z, w) / BaseField::from_u32_unchecked(1 << eval.domain.log_size())
|
||||
}
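// In scalar terms (illustrative only): for each of the four secure-field coordinates,
// the coefficient computed above is
//     lambda_c = (sum of the first half of the column - sum of the second half) / 2^log_size,
// where the halves are taken over the bit-reversed evaluation order and the division is a
// field division in M31. This is exercised end-to-end by `decomposition_test` below.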
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use itertools::Itertools;
|
||||
use num_traits::One;
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use crate::core::backend::simd::column::BaseColumn;
|
||||
use crate::core::backend::simd::SimdBackend;
|
||||
use crate::core::backend::{Column, CpuBackend};
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::secure_column::SecureColumnByCoords;
|
||||
use crate::core::fri::FriOps;
|
||||
use crate::core::poly::circle::{CanonicCoset, CirclePoly, PolyOps, SecureEvaluation};
|
||||
use crate::core::poly::line::{LineDomain, LineEvaluation};
|
||||
use crate::core::poly::BitReversedOrder;
|
||||
use crate::qm31;
|
||||
|
||||
#[test]
|
||||
fn test_fold_line() {
|
||||
const LOG_SIZE: u32 = 7;
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let values = (0..1 << LOG_SIZE).map(|_| rng.gen()).collect_vec();
|
||||
let alpha = qm31!(1, 3, 5, 7);
|
||||
let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset());
|
||||
let cpu_fold = CpuBackend::fold_line(
|
||||
&LineEvaluation::new(domain, values.iter().copied().collect()),
|
||||
alpha,
|
||||
&CpuBackend::precompute_twiddles(domain.coset()),
|
||||
);
|
||||
|
||||
let avx_fold = SimdBackend::fold_line(
|
||||
&LineEvaluation::new(domain, values.iter().copied().collect()),
|
||||
alpha,
|
||||
&SimdBackend::precompute_twiddles(domain.coset()),
|
||||
);
|
||||
|
||||
assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fold_circle_into_line() {
|
||||
const LOG_SIZE: u32 = 7;
|
||||
let values: Vec<SecureField> = (0..(1 << LOG_SIZE))
|
||||
.map(|i| qm31!(4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3))
|
||||
.collect();
|
||||
let alpha = qm31!(1, 3, 5, 7);
|
||||
let circle_domain = CanonicCoset::new(LOG_SIZE).circle_domain();
|
||||
let line_domain = LineDomain::new(circle_domain.half_coset);
|
||||
let mut cpu_fold = LineEvaluation::new(
|
||||
line_domain,
|
||||
SecureColumnByCoords::zeros(1 << (LOG_SIZE - 1)),
|
||||
);
|
||||
CpuBackend::fold_circle_into_line(
|
||||
&mut cpu_fold,
|
||||
&SecureEvaluation::new(circle_domain, values.iter().copied().collect()),
|
||||
alpha,
|
||||
&CpuBackend::precompute_twiddles(line_domain.coset()),
|
||||
);
|
||||
|
||||
let mut simd_fold = LineEvaluation::new(
|
||||
line_domain,
|
||||
SecureColumnByCoords::zeros(1 << (LOG_SIZE - 1)),
|
||||
);
|
||||
SimdBackend::fold_circle_into_line(
|
||||
&mut simd_fold,
|
||||
&SecureEvaluation::new(circle_domain, values.iter().copied().collect()),
|
||||
alpha,
|
||||
&SimdBackend::precompute_twiddles(line_domain.coset()),
|
||||
);
|
||||
|
||||
assert_eq!(cpu_fold.values.to_vec(), simd_fold.values.to_vec());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decomposition_test() {
|
||||
const DOMAIN_LOG_SIZE: u32 = 5;
|
||||
const DOMAIN_LOG_HALF_SIZE: u32 = DOMAIN_LOG_SIZE - 1;
|
||||
let s = CanonicCoset::new(DOMAIN_LOG_SIZE);
|
||||
let domain = s.circle_domain();
|
||||
let mut coeffs = BaseColumn::zeros(1 << DOMAIN_LOG_SIZE);
|
||||
// Polynomial is out of FFT space.
|
||||
coeffs.as_mut_slice()[1 << DOMAIN_LOG_HALF_SIZE] = BaseField::one();
|
||||
let poly = CirclePoly::<SimdBackend>::new(coeffs);
|
||||
let values = poly.evaluate(domain);
|
||||
let avx_column = SecureColumnByCoords::<SimdBackend> {
|
||||
columns: [
|
||||
values.values.clone(),
|
||||
values.values.clone(),
|
||||
values.values.clone(),
|
||||
values.values.clone(),
|
||||
],
|
||||
};
|
||||
let avx_eval = SecureEvaluation::new(domain, avx_column.clone());
|
||||
let cpu_eval =
|
||||
SecureEvaluation::<CpuBackend, BitReversedOrder>::new(domain, avx_eval.to_cpu());
|
||||
let (cpu_g, cpu_lambda) = CpuBackend::decompose(&cpu_eval);
|
||||
let (avx_g, avx_lambda) = SimdBackend::decompose(&avx_eval);
|
||||
|
||||
assert_eq!(avx_lambda, cpu_lambda);
|
||||
for i in 0..1 << DOMAIN_LOG_SIZE {
|
||||
assert_eq!(avx_g.values.at(i), cpu_g.values.at(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
95
Stwo_wrapper/crates/prover/src/core/backend/simd/grind.rs
Normal file
@ -0,0 +1,95 @@
|
||||
use std::simd::cmp::SimdPartialOrd;
|
||||
use std::simd::num::SimdUint;
|
||||
use std::simd::u32x16;
|
||||
|
||||
use bytemuck::cast_slice;
|
||||
#[cfg(feature = "parallel")]
|
||||
use rayon::prelude::*;
|
||||
|
||||
use super::blake2s::compress16;
|
||||
use super::SimdBackend;
|
||||
use crate::core::backend::simd::m31::N_LANES;
|
||||
use crate::core::channel::Blake2sChannel;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::core::channel::{Channel, Poseidon252Channel, PoseidonBLSChannel};
|
||||
use crate::core::proof_of_work::GrindOps;
|
||||
|
||||
// Note: GRIND_LOW_BITS is a cap on how much extra time we need to wait for all threads to finish.
|
||||
const GRIND_LOW_BITS: u32 = 20;
|
||||
const GRIND_HI_BITS: u32 = 64 - GRIND_LOW_BITS;
|
||||
|
||||
impl GrindOps<Blake2sChannel> for SimdBackend {
|
||||
fn grind(channel: &Blake2sChannel, pow_bits: u32) -> u64 {
|
||||
// TODO(spapini): support more than 32 bits.
|
||||
assert!(pow_bits <= 32, "pow_bits > 32 is not supported");
|
||||
let digest = channel.digest();
|
||||
let digest: &[u32] = cast_slice(&digest.0[..]);
|
||||
|
||||
#[cfg(not(feature = "parallel"))]
|
||||
let res = (0..=(1 << GRIND_HI_BITS))
|
||||
.find_map(|hi| grind_blake(digest, hi, pow_bits))
|
||||
.expect("Grind failed to find a solution.");
|
||||
|
||||
#[cfg(feature = "parallel")]
|
||||
let res = (0..=(1 << GRIND_HI_BITS))
|
||||
.into_par_iter()
|
||||
.find_map_any(|hi| grind_blake(digest, hi, pow_bits))
|
||||
.expect("Grind failed to find a solution.");
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
fn grind_blake(digest: &[u32], hi: u64, pow_bits: u32) -> Option<u64> {
|
||||
let zero: u32x16 = u32x16::default();
|
||||
let pow_bits = u32x16::splat(pow_bits);
|
||||
|
||||
let state: [u32x16; 8] = std::array::from_fn(|i| u32x16::splat(digest[i]));
|
||||
|
||||
let mut attempt = [zero; 16];
|
||||
attempt[0] = u32x16::splat((hi << GRIND_LOW_BITS) as u32);
|
||||
attempt[0] += u32x16::from(std::array::from_fn(|i| i as u32));
|
||||
attempt[1] = u32x16::splat((hi >> (32 - GRIND_LOW_BITS)) as u32);
|
||||
for low in (0..(1 << GRIND_LOW_BITS)).step_by(N_LANES) {
|
||||
let res = compress16(state, attempt, zero, zero, zero, zero);
|
||||
let success_mask = res[0].trailing_zeros().simd_ge(pow_bits);
|
||||
if success_mask.any() {
|
||||
let i = success_mask.to_array().iter().position(|&x| x).unwrap();
|
||||
return Some((hi << GRIND_LOW_BITS) + low as u64 + i as u64);
|
||||
}
|
||||
attempt[0] += u32x16::splat(N_LANES as u32);
|
||||
}
|
||||
None
|
||||
}
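// Sketch (illustrative only): how a candidate nonce is split across the hash input above.
// The low GRIND_LOW_BITS bits (plus the SIMD lane index) live in `attempt[0]`, and the
// remaining high bits spill into `attempt[1]`; the winning nonce is reassembled as
// `(hi << GRIND_LOW_BITS) + low + lane`.
#[cfg(test)]
mod grind_nonce_sketch {
    use super::GRIND_LOW_BITS;

    #[test]
    fn nonce_split_roundtrips() {
        for nonce in [0u64, 1, 12345, 1 << 40, u64::MAX >> 1] {
            let hi = nonce >> GRIND_LOW_BITS;
            let low = nonce & ((1 << GRIND_LOW_BITS) - 1);
            assert_eq!((hi << GRIND_LOW_BITS) + low, nonce);
        }
    }
}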
|
||||
|
||||
// TODO(spapini): This is a naive implementation. Optimize it.
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl GrindOps<Poseidon252Channel> for SimdBackend {
|
||||
fn grind(channel: &Poseidon252Channel, pow_bits: u32) -> u64 {
|
||||
let mut nonce = 0;
|
||||
loop {
|
||||
let mut channel = channel.clone();
|
||||
channel.mix_nonce(nonce);
|
||||
if channel.trailing_zeros() >= pow_bits {
|
||||
return nonce;
|
||||
}
|
||||
nonce += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(spapini): This is a naive implementation. Optimize it.
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl GrindOps<PoseidonBLSChannel> for SimdBackend {
|
||||
fn grind(channel: &PoseidonBLSChannel, pow_bits: u32) -> u64 {
|
||||
let mut nonce = 0;
|
||||
loop {
|
||||
let mut channel = channel.clone();
|
||||
channel.mix_nonce(nonce);
|
||||
if channel.trailing_zeros() >= pow_bits {
|
||||
return nonce;
|
||||
}
|
||||
nonce += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
684
Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/gkr.rs
Normal file
@ -0,0 +1,684 @@
|
||||
use std::iter::zip;
|
||||
|
||||
use num_traits::Zero;
|
||||
|
||||
use crate::core::backend::cpu::lookups::gkr::gen_eq_evals as cpu_gen_eq_evals;
|
||||
use crate::core::backend::simd::column::SecureColumn;
|
||||
use crate::core::backend::simd::m31::{LOG_N_LANES, N_LANES};
|
||||
use crate::core::backend::simd::qm31::PackedSecureField;
|
||||
use crate::core::backend::simd::SimdBackend;
|
||||
use crate::core::backend::{Column, CpuBackend};
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::lookups::gkr_prover::{
|
||||
correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer,
|
||||
};
|
||||
use crate::core::lookups::mle::Mle;
|
||||
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
|
||||
use crate::core::lookups::utils::{Fraction, Reciprocal, UnivariatePoly};
|
||||
|
||||
impl GkrOps for SimdBackend {
|
||||
#[allow(clippy::uninit_vec)]
|
||||
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField> {
|
||||
if y.len() < LOG_N_LANES as usize {
|
||||
return Mle::new(cpu_gen_eq_evals(y, v).into_iter().collect());
|
||||
}
|
||||
|
||||
// Start DP with CPU backend to avoid dealing with instances smaller than a SIMD vector.
|
||||
let (y_last_chunk, y_rem) = y.split_last_chunk::<{ LOG_N_LANES as usize }>().unwrap();
|
||||
let initial = SecureColumn::from_iter(cpu_gen_eq_evals(y_last_chunk, v));
|
||||
assert_eq!(initial.len(), N_LANES);
|
||||
|
||||
let packed_len = 1 << y_rem.len();
|
||||
let mut data = initial.data;
|
||||
|
||||
data.reserve(packed_len - data.len());
|
||||
unsafe { data.set_len(packed_len) };
|
||||
|
||||
for (i, &y_j) in y_rem.iter().rev().enumerate() {
|
||||
let packed_y_j = PackedSecureField::broadcast(y_j);
|
||||
|
||||
let (lhs_evals, rhs_evals) = data.split_at_mut(1 << i);
|
||||
|
||||
for (lhs, rhs) in zip(lhs_evals, rhs_evals) {
|
||||
// Equivalent to:
|
||||
// `rhs = eq(1, y_j) * lhs`,
|
||||
// `lhs = eq(0, y_j) * lhs`
|
||||
*rhs = *lhs * packed_y_j;
|
||||
*lhs -= *rhs;
|
||||
}
|
||||
}
|
||||
|
||||
let length = packed_len * N_LANES;
|
||||
Mle::new(SecureColumn { data, length })
|
||||
}
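    // In scalar terms (illustrative only): the table built above holds
    //     evals[b] = v * eq(bits(b), y),  with  eq(x, y) = prod_i (x_i*y_i + (1 - x_i)*(1 - y_i)),
    // where the first variable of `y` maps to the most significant bit of the index `b`.
    // E.g. for y = (y0, y1) the table is
    //     [v*(1-y0)*(1-y1), v*(1-y0)*y1, v*y0*(1-y1), v*y0*y1].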
|
||||
|
||||
fn next_layer(layer: &Layer<Self>) -> Layer<Self> {
|
||||
// Offload to CPU backend to avoid dealing with instances smaller than a SIMD vector.
|
||||
if layer.n_variables() as u32 <= LOG_N_LANES {
|
||||
return into_simd_layer(layer.to_cpu().next_layer().unwrap());
|
||||
}
|
||||
|
||||
match layer {
|
||||
Layer::GrandProduct(col) => next_grand_product_layer(col),
|
||||
Layer::LogUpGeneric {
|
||||
numerators,
|
||||
denominators,
|
||||
} => next_logup_generic_layer(numerators, denominators),
|
||||
Layer::LogUpMultiplicities {
|
||||
numerators,
|
||||
denominators,
|
||||
} => next_logup_multiplicities_layer(numerators, denominators),
|
||||
Layer::LogUpSingles { denominators } => next_logup_singles_layer(denominators),
|
||||
}
|
||||
}
|
||||
|
||||
fn sum_as_poly_in_first_variable(
|
||||
h: &GkrMultivariatePolyOracle<'_, Self>,
|
||||
claim: SecureField,
|
||||
) -> UnivariatePoly<SecureField> {
|
||||
let n_variables = h.n_variables();
|
||||
let n_terms = 1 << n_variables.saturating_sub(1);
|
||||
let eq_evals = h.eq_evals.as_ref();
|
||||
// Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube.
|
||||
let y = eq_evals.y();
|
||||
|
||||
// Offload to CPU backend to avoid dealing with instances smaller than a SIMD vector.
|
||||
if n_terms < N_LANES {
|
||||
return h.to_cpu().sum_as_poly_in_first_variable(claim);
|
||||
}
|
||||
|
||||
let n_packed_terms = n_terms / N_LANES;
|
||||
let packed_lambda = PackedSecureField::broadcast(h.lambda);
|
||||
|
||||
let (mut eval_at_0, mut eval_at_2) = match &h.input_layer {
|
||||
Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_packed_terms),
|
||||
Layer::LogUpGeneric {
|
||||
numerators,
|
||||
denominators,
|
||||
} => eval_logup_generic_sum(
|
||||
eq_evals,
|
||||
numerators,
|
||||
denominators,
|
||||
n_packed_terms,
|
||||
packed_lambda,
|
||||
),
|
||||
Layer::LogUpMultiplicities {
|
||||
numerators,
|
||||
denominators,
|
||||
} => eval_logup_multiplicities_sum(
|
||||
eq_evals,
|
||||
numerators,
|
||||
denominators,
|
||||
n_packed_terms,
|
||||
packed_lambda,
|
||||
),
|
||||
Layer::LogUpSingles { denominators } => {
|
||||
eval_logup_singles_sum(eq_evals, denominators, n_packed_terms, packed_lambda)
|
||||
}
|
||||
};
|
||||
|
||||
eval_at_0 *= h.eq_fixed_var_correction;
|
||||
eval_at_2 *= h.eq_fixed_var_correction;
|
||||
correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates the next GKR layer for Grand Product.
|
||||
///
|
||||
/// Assumption: `len(layer) > N_LANES`.
|
||||
fn next_grand_product_layer(layer: &Mle<SimdBackend, SecureField>) -> Layer<SimdBackend> {
|
||||
assert!(layer.len() > N_LANES);
|
||||
let next_layer_len = layer.len() / 2;
|
||||
|
||||
let data = layer
|
||||
.data
|
||||
.array_chunks()
|
||||
.map(|&[a, b]| {
|
||||
let (evens, odds) = a.deinterleave(b);
|
||||
evens * odds
|
||||
})
|
||||
.collect();
|
||||
|
||||
Layer::GrandProduct(Mle::new(SecureColumn {
|
||||
data,
|
||||
length: next_layer_len,
|
||||
}))
|
||||
}
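// In scalar terms (illustrative only): next_layer[i] = layer[2*i] * layer[2*i + 1].
// The deinterleave above just realigns the even- and odd-indexed lanes so that a single
// SIMD multiplication computes exactly these pairwise products.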
|
||||
|
||||
/// Generates the next GKR layer for LogUp.
|
||||
///
|
||||
/// Assumption: `len(denominators) > N_LANES`.
|
||||
fn next_logup_generic_layer(
|
||||
numerators: &Mle<SimdBackend, SecureField>,
|
||||
denominators: &Mle<SimdBackend, SecureField>,
|
||||
) -> Layer<SimdBackend> {
|
||||
assert!(denominators.len() > N_LANES);
|
||||
assert_eq!(numerators.len(), denominators.len());
|
||||
|
||||
let next_layer_len = denominators.len() / 2;
|
||||
let next_layer_packed_len = next_layer_len / N_LANES;
|
||||
|
||||
let mut next_numerators = Vec::with_capacity(next_layer_packed_len);
|
||||
let mut next_denominators = Vec::with_capacity(next_layer_packed_len);
|
||||
|
||||
for i in 0..next_layer_packed_len {
|
||||
let (n_even, n_odd) = numerators.data[i * 2].deinterleave(numerators.data[i * 2 + 1]);
|
||||
let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]);
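// Merge each adjacent pair of fractions: `n0/d0 + n1/d1 = (n0*d1 + n1*d0) / (d0*d1)`,
// halving the layer length.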
|
||||
|
||||
let Fraction {
|
||||
numerator,
|
||||
denominator,
|
||||
} = Fraction::new(n_even, d_even) + Fraction::new(n_odd, d_odd);
|
||||
|
||||
next_numerators.push(numerator);
|
||||
next_denominators.push(denominator);
|
||||
}
|
||||
|
||||
let next_numerators = SecureColumn {
|
||||
data: next_numerators,
|
||||
length: next_layer_len,
|
||||
};
|
||||
|
||||
let next_denominators = SecureColumn {
|
||||
data: next_denominators,
|
||||
length: next_layer_len,
|
||||
};
|
||||
|
||||
Layer::LogUpGeneric {
|
||||
numerators: Mle::new(next_numerators),
|
||||
denominators: Mle::new(next_denominators),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates the next GKR layer for LogUp.
|
||||
///
|
||||
/// Assumption: `len(denominators) > N_LANES`.
|
||||
// TODO(andrew): Code duplication of `next_logup_generic_layer`. Consider unifying these.
|
||||
fn next_logup_multiplicities_layer(
|
||||
numerators: &Mle<SimdBackend, BaseField>,
|
||||
denominators: &Mle<SimdBackend, SecureField>,
|
||||
) -> Layer<SimdBackend> {
|
||||
assert!(denominators.len() > N_LANES);
|
||||
assert_eq!(numerators.len(), denominators.len());
|
||||
|
||||
let next_layer_len = denominators.len() / 2;
|
||||
let next_layer_packed_len = next_layer_len / N_LANES;
|
||||
|
||||
let mut next_numerators = Vec::with_capacity(next_layer_packed_len);
|
||||
let mut next_denominators = Vec::with_capacity(next_layer_packed_len);
|
||||
|
||||
for i in 0..next_layer_packed_len {
|
||||
let (n_even, n_odd) = numerators.data[i * 2].deinterleave(numerators.data[i * 2 + 1]);
|
||||
let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]);
|
||||
|
||||
let Fraction {
|
||||
numerator,
|
||||
denominator,
|
||||
} = Fraction::new(n_even, d_even) + Fraction::new(n_odd, d_odd);
|
||||
|
||||
next_numerators.push(numerator);
|
||||
next_denominators.push(denominator);
|
||||
}
|
||||
|
||||
let next_numerators = SecureColumn {
|
||||
data: next_numerators,
|
||||
length: next_layer_len,
|
||||
};
|
||||
|
||||
let next_denominators = SecureColumn {
|
||||
data: next_denominators,
|
||||
length: next_layer_len,
|
||||
};
|
||||
|
||||
Layer::LogUpGeneric {
|
||||
numerators: Mle::new(next_numerators),
|
||||
denominators: Mle::new(next_denominators),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates the next GKR layer for LogUp.
|
||||
///
|
||||
/// Assumption: `len(denominators) > N_LANES`.
|
||||
fn next_logup_singles_layer(denominators: &Mle<SimdBackend, SecureField>) -> Layer<SimdBackend> {
|
||||
assert!(denominators.len() > N_LANES);
|
||||
|
||||
let next_layer_len = denominators.len() / 2;
|
||||
let next_layer_packed_len = next_layer_len / N_LANES;
|
||||
|
||||
let mut next_numerators = Vec::with_capacity(next_layer_packed_len);
|
||||
let mut next_denominators = Vec::with_capacity(next_layer_packed_len);
|
||||
|
||||
for i in 0..next_layer_packed_len {
|
||||
let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]);
|
||||
|
||||
let Fraction {
|
||||
numerator,
|
||||
denominator,
|
||||
} = Reciprocal::new(d_even) + Reciprocal::new(d_odd);
|
||||
|
||||
next_numerators.push(numerator);
|
||||
next_denominators.push(denominator);
|
||||
}
|
||||
|
||||
let next_numerators = SecureColumn {
|
||||
data: next_numerators,
|
||||
length: next_layer_len,
|
||||
};
|
||||
|
||||
let next_denominators = SecureColumn {
|
||||
data: next_denominators,
|
||||
length: next_layer_len,
|
||||
};
|
||||
|
||||
Layer::LogUpGeneric {
|
||||
numerators: Mle::new(next_numerators),
|
||||
denominators: Mle::new(next_denominators),
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`.
|
||||
///
|
||||
/// Output of the form: `(eval_at_0, eval_at_2)`.
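///
/// The evaluation at `t = 1` is not computed here: the caller reconstructs the full round
/// polynomial from these two evaluations together with the sumcheck claim via
/// `correct_sum_as_poly_in_first_variable`.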
|
||||
fn eval_grand_product_sum(
|
||||
eq_evals: &EqEvals<SimdBackend>,
|
||||
col: &Mle<SimdBackend, SecureField>,
|
||||
n_packed_terms: usize,
|
||||
) -> (SecureField, SecureField) {
|
||||
let mut packed_eval_at_0 = PackedSecureField::zero();
|
||||
let mut packed_eval_at_2 = PackedSecureField::zero();
|
||||
|
||||
for i in 0..n_packed_terms {
|
||||
// Input polynomial at points `(r, {0, 1, 2}, bits(i), v, {0, 1})`
|
||||
// for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
|
||||
let (inp_at_r0iv0, inp_at_r0iv1) = col.data[i * 2].deinterleave(col.data[i * 2 + 1]);
|
||||
let (inp_at_r1iv0, inp_at_r1iv1) =
|
||||
col.data[(n_packed_terms + i) * 2].deinterleave(col.data[(n_packed_terms + i) * 2 + 1]);
|
||||
// Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)`
|
||||
// => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)`
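// (since `eq(t, 0) = 1 - t` and `eq(t, 1) = t`, at `t = 2` the weights are `-1` and `2`).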
|
||||
let inp_at_r2iv0 = inp_at_r1iv0.double() - inp_at_r0iv0;
|
||||
let inp_at_r2iv1 = inp_at_r1iv1.double() - inp_at_r0iv1;
|
||||
|
||||
// Product polynomial `prod(x) = inp(x, 0) * inp(x, 1)` at points `(r, {0, 2}, bits(i), v)`.
|
||||
// for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
|
||||
let prod_at_r2iv = inp_at_r2iv0 * inp_at_r2iv1;
|
||||
let prod_at_r0iv = inp_at_r0iv0 * inp_at_r0iv1;
|
||||
|
||||
let eq_eval_at_0iv = eq_evals.data[i];
|
||||
packed_eval_at_0 += eq_eval_at_0iv * prod_at_r0iv;
|
||||
packed_eval_at_2 += eq_eval_at_0iv * prod_at_r2iv;
|
||||
}
|
||||
|
||||
(
|
||||
packed_eval_at_0.pointwise_sum(),
|
||||
packed_eval_at_2.pointwise_sum(),
|
||||
)
|
||||
}
|
||||
|
||||
fn eval_logup_generic_sum(
|
||||
eq_evals: &EqEvals<SimdBackend>,
|
||||
numerators: &Mle<SimdBackend, SecureField>,
|
||||
denominators: &Mle<SimdBackend, SecureField>,
|
||||
n_packed_terms: usize,
|
||||
packed_lambda: PackedSecureField,
|
||||
) -> (SecureField, SecureField) {
|
||||
let mut packed_eval_at_0 = PackedSecureField::zero();
|
||||
let mut packed_eval_at_2 = PackedSecureField::zero();
|
||||
|
||||
let inp_numer = &numerators.data;
|
||||
let inp_denom = &denominators.data;
|
||||
|
||||
for i in 0..n_packed_terms {
|
||||
// Input polynomials at points `(r, {0, 1, 2}, bits(i), v, {0, 1})`
|
||||
// for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
|
||||
let (inp_numer_at_r0iv0, inp_numer_at_r0iv1) =
|
||||
inp_numer[i * 2].deinterleave(inp_numer[i * 2 + 1]);
|
||||
let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) =
|
||||
inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]);
|
||||
let (inp_numer_at_r1iv0, inp_numer_at_r1iv1) = inp_numer[(n_packed_terms + i) * 2]
|
||||
.deinterleave(inp_numer[(n_packed_terms + i) * 2 + 1]);
|
||||
let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2]
|
||||
.deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]);
|
||||
// Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)`
|
||||
// => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)`
|
||||
let inp_numer_at_r2iv0 = inp_numer_at_r1iv0.double() - inp_numer_at_r0iv0;
|
||||
let inp_numer_at_r2iv1 = inp_numer_at_r1iv1.double() - inp_numer_at_r0iv1;
|
||||
let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0;
|
||||
let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1;
|
||||
|
||||
// Fraction addition polynomials:
|
||||
// - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)`
|
||||
// - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`.
|
||||
// at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
|
||||
let Fraction {
|
||||
numerator: numer_at_r0iv,
|
||||
denominator: denom_at_r0iv,
|
||||
} = Fraction::new(inp_numer_at_r0iv0, inp_denom_at_r0iv0)
|
||||
+ Fraction::new(inp_numer_at_r0iv1, inp_denom_at_r0iv1);
|
||||
let Fraction {
|
||||
numerator: numer_at_r2iv,
|
||||
denominator: denom_at_r2iv,
|
||||
} = Fraction::new(inp_numer_at_r2iv0, inp_denom_at_r2iv0)
|
||||
+ Fraction::new(inp_numer_at_r2iv1, inp_denom_at_r2iv1);
|
||||
|
||||
let eq_eval_at_0iv = eq_evals.data[i];
|
||||
packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv);
|
||||
packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv);
|
||||
}
|
||||
|
||||
(
|
||||
packed_eval_at_0.pointwise_sum(),
|
||||
packed_eval_at_2.pointwise_sum(),
|
||||
)
|
||||
}
|
||||
|
||||
// TODO(andrew): Code duplication of `eval_logup_generic_sum`. Consider unifying these.
|
||||
fn eval_logup_multiplicities_sum(
|
||||
eq_evals: &EqEvals<SimdBackend>,
|
||||
numerators: &Mle<SimdBackend, BaseField>,
|
||||
denominators: &Mle<SimdBackend, SecureField>,
|
||||
n_packed_terms: usize,
|
||||
packed_lambda: PackedSecureField,
|
||||
) -> (SecureField, SecureField) {
|
||||
let mut packed_eval_at_0 = PackedSecureField::zero();
|
||||
let mut packed_eval_at_2 = PackedSecureField::zero();
|
||||
|
||||
let inp_numer = &numerators.data;
|
||||
let inp_denom = &denominators.data;
|
||||
|
||||
for i in 0..n_packed_terms {
|
||||
// Input polynomials at points `(r, {0, 1, 2}, bits(i), v, {0, 1})`
|
||||
// for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
|
||||
let (inp_numer_at_r0iv0, inp_numer_at_r0iv1) =
|
||||
inp_numer[i * 2].deinterleave(inp_numer[i * 2 + 1]);
|
||||
let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) =
|
||||
inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]);
|
||||
let (inp_numer_at_r1iv0, inp_numer_at_r1iv1) = inp_numer[(n_packed_terms + i) * 2]
|
||||
.deinterleave(inp_numer[(n_packed_terms + i) * 2 + 1]);
|
||||
let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2]
|
||||
.deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]);
|
||||
// Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)`
|
||||
// => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)`
|
||||
let inp_numer_at_r2iv0 = inp_numer_at_r1iv0.double() - inp_numer_at_r0iv0;
|
||||
let inp_numer_at_r2iv1 = inp_numer_at_r1iv1.double() - inp_numer_at_r0iv1;
|
||||
let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0;
|
||||
let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1;
|
||||
|
||||
// Fraction addition polynomials:
|
||||
// - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)`
|
||||
// - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`.
|
||||
// at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
|
||||
let Fraction {
|
||||
numerator: numer_at_r0iv,
|
||||
denominator: denom_at_r0iv,
|
||||
} = Fraction::new(inp_numer_at_r0iv0, inp_denom_at_r0iv0)
|
||||
+ Fraction::new(inp_numer_at_r0iv1, inp_denom_at_r0iv1);
|
||||
let Fraction {
|
||||
numerator: numer_at_r2iv,
|
||||
denominator: denom_at_r2iv,
|
||||
} = Fraction::new(inp_numer_at_r2iv0, inp_denom_at_r2iv0)
|
||||
+ Fraction::new(inp_numer_at_r2iv1, inp_denom_at_r2iv1);
|
||||
|
||||
let eq_eval_at_0iv = eq_evals.data[i];
|
||||
packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv);
|
||||
packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv);
|
||||
}
|
||||
|
||||
(
|
||||
packed_eval_at_0.pointwise_sum(),
|
||||
packed_eval_at_2.pointwise_sum(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_denom(r, t, x, 1) + inp_denom(r, t, x, 0) +
|
||||
/// lambda * inp_denom(r, t, x, 0) * inp_denom(r, t, x, 1))` at `t=0` and `t=2`.
|
||||
///
|
||||
/// Output of the form: `(eval_at_0, eval_at_2)`.
|
||||
fn eval_logup_singles_sum(
|
||||
eq_evals: &EqEvals<SimdBackend>,
|
||||
denominators: &Mle<SimdBackend, SecureField>,
|
||||
n_packed_terms: usize,
|
||||
packed_lambda: PackedSecureField,
|
||||
) -> (SecureField, SecureField) {
|
||||
let mut packed_eval_at_0 = PackedSecureField::zero();
|
||||
let mut packed_eval_at_2 = PackedSecureField::zero();
|
||||
|
||||
let inp_denom = &denominators.data;
|
||||
|
||||
for i in 0..n_packed_terms {
|
||||
// Input polynomial at points `(r, {0, 1, 2}, bits(i), v, {0, 1})`
|
||||
// for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
|
||||
let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) =
|
||||
inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]);
|
||||
let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2]
|
||||
.deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]);
|
||||
// Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)`
|
||||
// => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)`
|
||||
let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0;
|
||||
let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1;
|
||||
|
||||
// Fraction addition polynomials:
|
||||
// - `numer(x) = inp_denom(x, 1) + inp_denom(x, 0)`
|
||||
// - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`.
|
||||
// at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
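// This is just `1/d0 + 1/d1 = (d0 + d1) / (d0 * d1)`, computed below by summing two `Reciprocal`s.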
|
||||
let Fraction {
|
||||
numerator: numer_at_r0iv,
|
||||
denominator: denom_at_r0iv,
|
||||
} = Reciprocal::new(inp_denom_at_r0iv0) + Reciprocal::new(inp_denom_at_r0iv1);
|
||||
let Fraction {
|
||||
numerator: numer_at_r2iv,
|
||||
denominator: denom_at_r2iv,
|
||||
} = Reciprocal::new(inp_denom_at_r2iv0) + Reciprocal::new(inp_denom_at_r2iv1);
|
||||
|
||||
let eq_eval_at_0iv = eq_evals.data[i];
|
||||
packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv);
|
||||
packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv);
|
||||
}
|
||||
|
||||
(
|
||||
packed_eval_at_0.pointwise_sum(),
|
||||
packed_eval_at_2.pointwise_sum(),
|
||||
)
|
||||
}
|
||||
|
||||
fn into_simd_layer(cpu_layer: Layer<CpuBackend>) -> Layer<SimdBackend> {
|
||||
match cpu_layer {
|
||||
Layer::GrandProduct(mle) => {
|
||||
Layer::GrandProduct(Mle::new(mle.into_evals().into_iter().collect()))
|
||||
}
|
||||
Layer::LogUpGeneric {
|
||||
numerators,
|
||||
denominators,
|
||||
} => Layer::LogUpGeneric {
|
||||
numerators: Mle::new(numerators.into_evals().into_iter().collect()),
|
||||
denominators: Mle::new(denominators.into_evals().into_iter().collect()),
|
||||
},
|
||||
Layer::LogUpMultiplicities {
|
||||
numerators,
|
||||
denominators,
|
||||
} => Layer::LogUpMultiplicities {
|
||||
numerators: Mle::new(numerators.into_evals().into_iter().collect()),
|
||||
denominators: Mle::new(denominators.into_evals().into_iter().collect()),
|
||||
},
|
||||
Layer::LogUpSingles { denominators } => Layer::LogUpSingles {
|
||||
denominators: Mle::new(denominators.into_evals().into_iter().collect()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::iter::zip;
|
||||
|
||||
use num_traits::One;
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use crate::core::backend::simd::SimdBackend;
|
||||
use crate::core::backend::{Column, CpuBackend};
|
||||
use crate::core::channel::Channel;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer};
|
||||
use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError};
|
||||
use crate::core::lookups::mle::Mle;
|
||||
use crate::core::lookups::utils::Fraction;
|
||||
use crate::core::test_utils::test_channel;
|
||||
|
||||
#[test]
|
||||
fn gen_eq_evals_matches_cpu() {
|
||||
let two = BaseField::from(2).into();
|
||||
let y = [7, 3, 5, 6, 1, 1, 9].map(|v| BaseField::from(v).into());
|
||||
let eq_evals_cpu = CpuBackend::gen_eq_evals(&y, two);
|
||||
|
||||
let eq_evals_simd = SimdBackend::gen_eq_evals(&y, two);
|
||||
|
||||
assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gen_eq_evals_with_small_assignment_matches_cpu() {
|
||||
let two = BaseField::from(2).into();
|
||||
let y = [7, 3, 5].map(|v| BaseField::from(v).into());
|
||||
let eq_evals_cpu = CpuBackend::gen_eq_evals(&y, two);
|
||||
|
||||
let eq_evals_simd = SimdBackend::gen_eq_evals(&y, two);
|
||||
|
||||
assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn grand_product_works() -> Result<(), GkrError> {
|
||||
const N: usize = 1 << 8;
|
||||
let values = test_channel().draw_felts(N);
|
||||
let product = values.iter().product();
|
||||
let col = Mle::<SimdBackend, SecureField>::new(values.into_iter().collect());
|
||||
let input_layer = Layer::GrandProduct(col.clone());
|
||||
let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);
|
||||
|
||||
let GkrArtifact {
|
||||
ood_point,
|
||||
claims_to_verify_by_instance,
|
||||
n_variables_by_instance: _,
|
||||
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;
|
||||
|
||||
assert_eq!(proof.output_claims_by_instance, [vec![product]]);
|
||||
assert_eq!(
|
||||
claims_to_verify_by_instance,
|
||||
[vec![col.eval_at_point(&ood_point)]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn logup_with_generic_trace_works() -> Result<(), GkrError> {
|
||||
const N: usize = 1 << 8;
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let numerators = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
|
||||
let denominators = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
|
||||
let sum = zip(&numerators, &denominators)
|
||||
.map(|(&n, &d)| Fraction::new(n, d))
|
||||
.sum::<Fraction<SecureField, SecureField>>();
|
||||
let numerators = Mle::<SimdBackend, SecureField>::new(numerators.into_iter().collect());
|
||||
let denominators = Mle::<SimdBackend, SecureField>::new(denominators.into_iter().collect());
|
||||
let input_layer = Layer::LogUpGeneric {
|
||||
numerators: numerators.clone(),
|
||||
denominators: denominators.clone(),
|
||||
};
|
||||
let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);
|
||||
|
||||
let GkrArtifact {
|
||||
ood_point,
|
||||
claims_to_verify_by_instance,
|
||||
n_variables_by_instance: _,
|
||||
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;
|
||||
|
||||
assert_eq!(claims_to_verify_by_instance.len(), 1);
|
||||
assert_eq!(proof.output_claims_by_instance.len(), 1);
|
||||
assert_eq!(
|
||||
claims_to_verify_by_instance[0],
|
||||
[
|
||||
numerators.eval_at_point(&ood_point),
|
||||
denominators.eval_at_point(&ood_point)
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
proof.output_claims_by_instance[0],
|
||||
[sum.numerator, sum.denominator]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn logup_with_multiplicities_trace_works() -> Result<(), GkrError> {
|
||||
const N: usize = 1 << 8;
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let numerators = (0..N).map(|_| rng.gen()).collect::<Vec<BaseField>>();
|
||||
let denominators = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
|
||||
let sum = zip(&numerators, &denominators)
|
||||
.map(|(&n, &d)| Fraction::new(n.into(), d))
|
||||
.sum::<Fraction<SecureField, SecureField>>();
|
||||
let numerators = Mle::<SimdBackend, BaseField>::new(numerators.into_iter().collect());
|
||||
let denominators = Mle::<SimdBackend, SecureField>::new(denominators.into_iter().collect());
|
||||
let input_layer = Layer::LogUpMultiplicities {
|
||||
numerators: numerators.clone(),
|
||||
denominators: denominators.clone(),
|
||||
};
|
||||
let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);
|
||||
|
||||
let GkrArtifact {
|
||||
ood_point,
|
||||
claims_to_verify_by_instance,
|
||||
n_variables_by_instance: _,
|
||||
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;
|
||||
|
||||
assert_eq!(claims_to_verify_by_instance.len(), 1);
|
||||
assert_eq!(proof.output_claims_by_instance.len(), 1);
|
||||
assert_eq!(
|
||||
claims_to_verify_by_instance[0],
|
||||
[
|
||||
numerators.eval_at_point(&ood_point),
|
||||
denominators.eval_at_point(&ood_point)
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
proof.output_claims_by_instance[0],
|
||||
[sum.numerator, sum.denominator]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn logup_with_singles_trace_works() -> Result<(), GkrError> {
|
||||
const N: usize = 1 << 8;
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let denominators = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
|
||||
let sum = denominators
|
||||
.iter()
|
||||
.map(|&d| Fraction::new(SecureField::one(), d))
|
||||
.sum::<Fraction<SecureField, SecureField>>();
|
||||
let denominators = Mle::<SimdBackend, SecureField>::new(denominators.into_iter().collect());
|
||||
let input_layer = Layer::LogUpSingles {
|
||||
denominators: denominators.clone(),
|
||||
};
|
||||
let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);
|
||||
|
||||
let GkrArtifact {
|
||||
ood_point,
|
||||
claims_to_verify_by_instance,
|
||||
n_variables_by_instance: _,
|
||||
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;
|
||||
|
||||
assert_eq!(claims_to_verify_by_instance.len(), 1);
|
||||
assert_eq!(proof.output_claims_by_instance.len(), 1);
|
||||
assert_eq!(
|
||||
claims_to_verify_by_instance[0],
|
||||
[SecureField::one(), denominators.eval_at_point(&ood_point)]
|
||||
);
|
||||
assert_eq!(
|
||||
proof.output_claims_by_instance[0],
|
||||
[sum.numerator, sum.denominator]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
132
Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/mle.rs
Normal file
@ -0,0 +1,132 @@
|
||||
use core::ops::Sub;
|
||||
use std::iter::zip;
|
||||
use std::ops::{Add, Mul};
|
||||
|
||||
use crate::core::backend::simd::column::SecureColumn;
|
||||
use crate::core::backend::simd::m31::N_LANES;
|
||||
use crate::core::backend::simd::qm31::PackedSecureField;
|
||||
use crate::core::backend::simd::SimdBackend;
|
||||
use crate::core::backend::{Column, CpuBackend};
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::lookups::mle::{Mle, MleOps};
|
||||
|
||||
impl MleOps<BaseField> for SimdBackend {
|
||||
fn fix_first_variable(
|
||||
mle: Mle<Self, BaseField>,
|
||||
assignment: SecureField,
|
||||
) -> Mle<Self, SecureField> {
|
||||
let midpoint = mle.len() / 2;
|
||||
|
||||
// Use CPU backend to avoid dealing with instances smaller than `PackedSecureField`.
|
||||
if midpoint < N_LANES {
|
||||
let cpu_mle = Mle::<CpuBackend, BaseField>::new(mle.to_cpu());
|
||||
let cpu_res = cpu_mle.fix_first_variable(assignment);
|
||||
return Mle::new(cpu_res.into_evals().into_iter().collect());
|
||||
}
|
||||
|
||||
let packed_assignment = PackedSecureField::broadcast(assignment);
|
||||
let packed_midpoint = midpoint / N_LANES;
|
||||
let (evals_at_0x, evals_at_1x) = mle.data.split_at(packed_midpoint);
|
||||
|
||||
let res = zip(evals_at_0x, evals_at_1x)
|
||||
.enumerate()
|
||||
// MLE at points `({0, 1}, rev(bits(i)), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
|
||||
.map(|(_i, (&packed_eval_at_0iv, &packed_eval_at_1iv))| {
|
||||
fold_packed_mle_evals(packed_assignment, packed_eval_at_0iv, packed_eval_at_1iv)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Mle::new(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl MleOps<SecureField> for SimdBackend {
|
||||
fn fix_first_variable(
|
||||
mle: Mle<Self, SecureField>,
|
||||
assignment: SecureField,
|
||||
) -> Mle<Self, SecureField> {
|
||||
let midpoint = mle.len() / 2;
|
||||
|
||||
// Use CPU backend to avoid dealing with instances smaller than `PackedSecureField`.
|
||||
if midpoint < N_LANES {
|
||||
let cpu_mle = Mle::<CpuBackend, SecureField>::new(mle.to_cpu());
|
||||
let cpu_res = cpu_mle.fix_first_variable(assignment);
|
||||
return Mle::new(cpu_res.into_evals().into_iter().collect());
|
||||
}
|
||||
|
||||
let packed_midpoint = midpoint / N_LANES;
|
||||
let packed_assignment = PackedSecureField::broadcast(assignment);
|
||||
let mut packed_evals = mle.into_evals().data;
|
||||
|
||||
for i in 0..packed_midpoint {
|
||||
// MLE at points `({0, 1}, rev(bits(i)), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
|
||||
let packed_eval_at_0iv = packed_evals[i];
|
||||
let packed_eval_at_1iv = packed_evals[i + packed_midpoint];
|
||||
packed_evals[i] =
|
||||
fold_packed_mle_evals(packed_assignment, packed_eval_at_0iv, packed_eval_at_1iv);
|
||||
}
|
||||
|
||||
packed_evals.truncate(packed_midpoint);
|
||||
|
||||
let length = packed_evals.len() * N_LANES;
|
||||
let data = packed_evals;
|
||||
|
||||
Mle::new(SecureColumn { data, length })
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes `eq(0, assignment_i) * eval0_i + eq(1, assignment_i) * eval1_i` for each lane `i`.
|
||||
// TODO(andrew): Remove complex trait bounds once we have something like
|
||||
// AbstractField/AbstractExtensionField traits.
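// Since `eq(0, a) = 1 - a` and `eq(1, a) = a`, the folded value is
// `(1 - a) * eval0 + a * eval1 = a * (eval1 - eval0) + eval0`, which is the form used below.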
|
||||
fn fold_packed_mle_evals<
|
||||
PackedF: Sub<Output = PackedF> + Copy,
|
||||
PackedEF: Mul<PackedF, Output = PackedEF> + Add<PackedF, Output = PackedEF>,
|
||||
>(
|
||||
assignment: PackedEF,
|
||||
eval0: PackedF,
|
||||
eval1: PackedF,
|
||||
) -> PackedEF {
|
||||
assignment * (eval1 - eval0) + eval0
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::core::backend::simd::SimdBackend;
|
||||
use crate::core::backend::{Column, CpuBackend};
|
||||
use crate::core::channel::Channel;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::lookups::mle::Mle;
|
||||
use crate::core::test_utils::test_channel;
|
||||
|
||||
#[test]
|
||||
fn fix_first_variable_with_secure_field_mle_matches_cpu() {
|
||||
const N_VARIABLES: u32 = 8;
|
||||
let values = test_channel().draw_felts(1 << N_VARIABLES);
|
||||
let mle_simd = Mle::<SimdBackend, SecureField>::new(values.iter().copied().collect());
|
||||
let mle_cpu = Mle::<CpuBackend, SecureField>::new(values);
|
||||
let random_assignment = SecureField::from_u32_unchecked(7, 12, 3, 2);
|
||||
let mle_fixed_cpu = mle_cpu.fix_first_variable(random_assignment);
|
||||
|
||||
let mle_fixed_simd = mle_simd.fix_first_variable(random_assignment);
|
||||
|
||||
assert_eq!(mle_fixed_simd.into_evals().to_cpu(), *mle_fixed_cpu)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fix_first_variable_with_base_field_mle_matches_cpu() {
|
||||
const N_VARIABLES: u32 = 8;
|
||||
let values = (0..1 << N_VARIABLES).map(BaseField::from).collect_vec();
|
||||
let mle_simd = Mle::<SimdBackend, BaseField>::new(values.iter().copied().collect());
|
||||
let mle_cpu = Mle::<CpuBackend, BaseField>::new(values);
|
||||
let random_assignment = SecureField::from_u32_unchecked(7, 12, 3, 2);
|
||||
let mle_fixed_cpu = mle_cpu.fix_first_variable(random_assignment);
|
||||
|
||||
let mle_fixed_simd = mle_simd.fix_first_variable(random_assignment);
|
||||
|
||||
assert_eq!(mle_fixed_simd.into_evals().to_cpu(), *mle_fixed_cpu)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,2 @@
|
||||
mod gkr;
|
||||
mod mle;
|
||||
666
Stwo_wrapper/crates/prover/src/core/backend/simd/m31.rs
Normal file
@ -0,0 +1,666 @@
|
||||
use std::iter::Sum;
|
||||
use std::mem::transmute;
|
||||
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
|
||||
use std::ptr;
|
||||
use std::simd::cmp::SimdOrd;
|
||||
use std::simd::{u32x16, Simd, Swizzle};
|
||||
|
||||
use bytemuck::{Pod, Zeroable};
|
||||
use num_traits::{One, Zero};
|
||||
use rand::distributions::{Distribution, Standard};
|
||||
|
||||
use super::qm31::PackedQM31;
|
||||
use crate::core::backend::simd::utils::{InterleaveEvens, InterleaveOdds};
|
||||
use crate::core::fields::m31::{pow2147483645, BaseField, M31, P};
|
||||
use crate::core::fields::qm31::QM31;
|
||||
use crate::core::fields::FieldExpOps;
|
||||
|
||||
pub const LOG_N_LANES: u32 = 4;
|
||||
|
||||
pub const N_LANES: usize = 1 << LOG_N_LANES;
|
||||
|
||||
pub const MODULUS: Simd<u32, N_LANES> = Simd::from_array([P; N_LANES]);
|
||||
|
||||
pub type PackedBaseField = PackedM31;
|
||||
|
||||
/// Holds a vector of unreduced [`M31`] elements in the range `[0, P]`.
|
||||
///
|
||||
/// Implemented with [`std::simd`] to support multiple targets (avx512, neon, wasm etc.).
|
||||
// TODO: Remove `pub` visibility
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
#[repr(transparent)]
|
||||
pub struct PackedM31(Simd<u32, N_LANES>);
|
||||
|
||||
impl PackedM31 {
|
||||
/// Constructs a new instance with all vector elements set to `value`.
|
||||
pub fn broadcast(M31(value): M31) -> Self {
|
||||
Self(Simd::splat(value))
|
||||
}
|
||||
|
||||
pub fn from_array(values: [M31; N_LANES]) -> PackedM31 {
|
||||
Self(Simd::from_array(values.map(|M31(v)| v)))
|
||||
}
|
||||
|
||||
pub fn to_array(self) -> [M31; N_LANES] {
|
||||
self.reduce().0.to_array().map(M31)
|
||||
}
|
||||
|
||||
/// Reduces each element of the vector to the range `[0, P)`.
|
||||
fn reduce(self) -> PackedM31 {
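// For inputs in `[0, P]`, `v - P` wraps around whenever `v < P`, so `min(v, v - P)` leaves
// `[0, P)` unchanged and maps `P` to `0`.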
|
||||
Self(Simd::simd_min(self.0, self.0 - MODULUS))
|
||||
}
|
||||
|
||||
/// Interleaves two vectors.
|
||||
pub fn interleave(self, other: Self) -> (Self, Self) {
|
||||
let (a, b) = self.0.interleave(other.0);
|
||||
(Self(a), Self(b))
|
||||
}
|
||||
|
||||
/// Deinterleaves two vectors.
|
||||
pub fn deinterleave(self, other: Self) -> (Self, Self) {
|
||||
let (a, b) = self.0.deinterleave(other.0);
|
||||
(Self(a), Self(b))
|
||||
}
|
||||
|
||||
/// Reverses the order of the elements in the vector.
|
||||
pub fn reverse(self) -> Self {
|
||||
Self(self.0.reverse())
|
||||
}
|
||||
|
||||
/// Sums all the elements in the vector.
|
||||
pub fn pointwise_sum(self) -> M31 {
|
||||
self.to_array().into_iter().sum()
|
||||
}
|
||||
|
||||
/// Doubles each element in the vector.
|
||||
pub fn double(self) -> Self {
|
||||
// TODO: Make more optimal.
|
||||
self + self
|
||||
}
|
||||
|
||||
pub fn into_simd(self) -> Simd<u32, N_LANES> {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// Vector elements must be in the range `[0, P]`.
|
||||
pub unsafe fn from_simd_unchecked(v: Simd<u32, N_LANES>) -> Self {
|
||||
Self(v)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if the pointer does not have the same alignment as
|
||||
/// [`PackedM31`]. The loaded `u32` values must be in the range `[0, P]`.
|
||||
pub unsafe fn load(mem_addr: *const u32) -> Self {
|
||||
Self(ptr::read(mem_addr as *const u32x16))
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// Behavior is undefined if the pointer does not have the same alignment as
|
||||
/// [`PackedM31`].
|
||||
pub unsafe fn store(self, dst: *mut u32) {
|
||||
ptr::write(dst as *mut u32x16, self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Add for PackedM31 {
|
||||
type Output = Self;
|
||||
|
||||
#[inline(always)]
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
// Add word by word. Each word is in the range [0, 2P].
|
||||
let c = self.0 + rhs.0;
|
||||
// Apply min(c, c-P) to each word.
|
||||
// When c in [P,2P], then c-P in [0,P] which is always less than [P,2P].
|
||||
// When c in [0,P-1], then c-P in [2^32-P,2^32-1] which is always greater than [0,P-1].
|
||||
Self(Simd::simd_min(c, c - MODULUS))
|
||||
}
|
||||
}
|
||||
|
||||
impl AddAssign for PackedM31 {
|
||||
#[inline(always)]
|
||||
fn add_assign(&mut self, rhs: Self) {
|
||||
*self = *self + rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl AddAssign<M31> for PackedM31 {
|
||||
#[inline(always)]
|
||||
fn add_assign(&mut self, rhs: M31) {
|
||||
*self = *self + PackedM31::broadcast(rhs);
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul for PackedM31 {
|
||||
type Output = Self;
|
||||
|
||||
#[inline(always)]
|
||||
fn mul(self, rhs: Self) -> Self {
|
||||
// TODO: Come up with a better approach than `cfg`ing on target_feature.
|
||||
// TODO: Ensure all these branches get tested in the CI.
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] {
|
||||
_mul_neon(self, rhs)
|
||||
} else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] {
|
||||
_mul_wasm(self, rhs)
|
||||
} else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] {
|
||||
_mul_avx512(self, rhs)
|
||||
} else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] {
|
||||
_mul_avx2(self, rhs)
|
||||
} else {
|
||||
_mul_simd(self, rhs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<M31> for PackedM31 {
|
||||
type Output = Self;
|
||||
|
||||
#[inline(always)]
|
||||
fn mul(self, rhs: M31) -> Self::Output {
|
||||
self * PackedM31::broadcast(rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<M31> for PackedM31 {
|
||||
type Output = PackedM31;
|
||||
|
||||
#[inline(always)]
|
||||
fn add(self, rhs: M31) -> Self::Output {
|
||||
PackedM31::broadcast(rhs) + self
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<QM31> for PackedM31 {
|
||||
type Output = PackedQM31;
|
||||
|
||||
#[inline(always)]
|
||||
fn add(self, rhs: QM31) -> Self::Output {
|
||||
PackedQM31::broadcast(rhs) + self
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<QM31> for PackedM31 {
|
||||
type Output = PackedQM31;
|
||||
|
||||
#[inline(always)]
|
||||
fn mul(self, rhs: QM31) -> Self::Output {
|
||||
PackedQM31::broadcast(rhs) * self
|
||||
}
|
||||
}
|
||||
|
||||
impl MulAssign for PackedM31 {
|
||||
#[inline(always)]
|
||||
fn mul_assign(&mut self, rhs: Self) {
|
||||
*self = *self * rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl Neg for PackedM31 {
|
||||
type Output = Self;
|
||||
|
||||
#[inline(always)]
|
||||
fn neg(self) -> Self::Output {
|
||||
Self(MODULUS - self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub for PackedM31 {
|
||||
type Output = Self;
|
||||
|
||||
#[inline(always)]
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
// Subtract word by word. Each word is in the range [-P, P].
|
||||
let c = self.0 - rhs.0;
|
||||
// Apply min(c, c+P) to each word.
|
||||
// When c in [0,P], then c+P in [P,2P] which is always greater than [0,P].
|
||||
// When c in [2^32-P,2^32-1], then c+P in [0,P-1] which is always less than
|
||||
// [2^32-P,2^32-1].
|
||||
Self(Simd::simd_min(c + MODULUS, c))
|
||||
}
|
||||
}
|
||||
|
||||
impl SubAssign for PackedM31 {
|
||||
#[inline(always)]
|
||||
fn sub_assign(&mut self, rhs: Self) {
|
||||
*self = *self - rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl Zero for PackedM31 {
|
||||
fn zero() -> Self {
|
||||
Self(Simd::from_array([0; N_LANES]))
|
||||
}
|
||||
|
||||
fn is_zero(&self) -> bool {
|
||||
self.to_array().iter().all(M31::is_zero)
|
||||
}
|
||||
}
|
||||
|
||||
impl One for PackedM31 {
|
||||
fn one() -> Self {
|
||||
Self(Simd::<u32, N_LANES>::from_array([1; N_LANES]))
|
||||
}
|
||||
}
|
||||
|
||||
impl FieldExpOps for PackedM31 {
|
||||
fn inverse(&self) -> Self {
|
||||
assert!(!self.is_zero(), "0 has no inverse");
|
||||
pow2147483645(*self)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Pod for PackedM31 {}
|
||||
|
||||
unsafe impl Zeroable for PackedM31 {
|
||||
fn zeroed() -> Self {
|
||||
unsafe { core::mem::zeroed() }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[BaseField; N_LANES]> for PackedM31 {
|
||||
fn from(v: [BaseField; N_LANES]) -> Self {
|
||||
Self::from_array(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<BaseField> for PackedM31 {
|
||||
fn from(v: BaseField) -> Self {
|
||||
Self::broadcast(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl Distribution<PackedM31> for Standard {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> PackedM31 {
|
||||
PackedM31::from_array(rng.gen())
|
||||
}
|
||||
}
|
||||
|
||||
impl Sum for PackedM31 {
|
||||
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
|
||||
iter.fold(Self::zero(), Add::add)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `a * b`.
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
pub(crate) fn _mul_neon(a: PackedM31, b: PackedM31) -> PackedM31 {
|
||||
use core::arch::aarch64::{int32x2_t, vqdmull_s32};
|
||||
use std::simd::u32x4;
|
||||
|
||||
let [a0, a1, a2, a3, a4, a5, a6, a7]: [int32x2_t; 8] = unsafe { transmute(a) };
|
||||
let [b0, b1, b2, b3, b4, b5, b6, b7]: [int32x2_t; 8] = unsafe { transmute(b) };
|
||||
|
||||
// Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0|
|
||||
let c0: u32x4 = unsafe { transmute(vqdmull_s32(a0, b0)) };
|
||||
let c1: u32x4 = unsafe { transmute(vqdmull_s32(a1, b1)) };
|
||||
let c2: u32x4 = unsafe { transmute(vqdmull_s32(a2, b2)) };
|
||||
let c3: u32x4 = unsafe { transmute(vqdmull_s32(a3, b3)) };
|
||||
let c4: u32x4 = unsafe { transmute(vqdmull_s32(a4, b4)) };
|
||||
let c5: u32x4 = unsafe { transmute(vqdmull_s32(a5, b5)) };
|
||||
let c6: u32x4 = unsafe { transmute(vqdmull_s32(a6, b6)) };
|
||||
let c7: u32x4 = unsafe { transmute(vqdmull_s32(a7, b7)) };
|
||||
|
||||
// *_lo contain `|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|0|`.
|
||||
// *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`.
|
||||
let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1);
|
||||
let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3);
|
||||
let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5);
|
||||
let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7);
|
||||
|
||||
// *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`.
|
||||
c0_c1_lo >>= 1;
|
||||
c2_c3_lo >>= 1;
|
||||
c4_c5_lo >>= 1;
|
||||
c6_c7_lo >>= 1;
|
||||
|
||||
let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) };
|
||||
let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) };
|
||||
|
||||
lo + hi
|
||||
}
|
||||
|
||||
/// Returns `a * b`.
|
||||
///
|
||||
/// `b_double` should be in the range `[0, 2P]`.
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
pub(crate) fn _mul_doubled_neon(a: PackedM31, b_double: u32x16) -> PackedM31 {
|
||||
use core::arch::aarch64::{uint32x2_t, vmull_u32};
|
||||
use std::simd::u32x4;
|
||||
|
||||
let [a0, a1, a2, a3, a4, a5, a6, a7]: [uint32x2_t; 8] = unsafe { transmute(a) };
|
||||
let [b0, b1, b2, b3, b4, b5, b6, b7]: [uint32x2_t; 8] = unsafe { transmute(b_double) };
|
||||
|
||||
// Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0|
|
||||
let c0: u32x4 = unsafe { transmute(vmull_u32(a0, b0)) };
|
||||
let c1: u32x4 = unsafe { transmute(vmull_u32(a1, b1)) };
|
||||
let c2: u32x4 = unsafe { transmute(vmull_u32(a2, b2)) };
|
||||
let c3: u32x4 = unsafe { transmute(vmull_u32(a3, b3)) };
|
||||
let c4: u32x4 = unsafe { transmute(vmull_u32(a4, b4)) };
|
||||
let c5: u32x4 = unsafe { transmute(vmull_u32(a5, b5)) };
|
||||
let c6: u32x4 = unsafe { transmute(vmull_u32(a6, b6)) };
|
||||
let c7: u32x4 = unsafe { transmute(vmull_u32(a7, b7)) };
|
||||
|
||||
// *_lo contain `|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|0|`.
|
||||
// *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`.
|
||||
let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1);
|
||||
let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3);
|
||||
let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5);
|
||||
let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7);
|
||||
|
||||
// *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`.
|
||||
c0_c1_lo >>= 1;
|
||||
c2_c3_lo >>= 1;
|
||||
c4_c5_lo >>= 1;
|
||||
c6_c7_lo >>= 1;
|
||||
|
||||
let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) };
|
||||
let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) };
|
||||
|
||||
lo + hi
|
||||
}
|
||||
|
||||
/// Returns `a * b`.
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub(crate) fn _mul_wasm(a: PackedM31, b: PackedM31) -> PackedM31 {
|
||||
_mul_doubled_wasm(a, b.0 + b.0)
|
||||
}
|
||||
|
||||
/// Returns `a * b`.
|
||||
///
|
||||
/// `b_double` should be in the range `[0, 2P]`.
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub(crate) fn _mul_doubled_wasm(a: PackedM31, b_double: u32x16) -> PackedM31 {
|
||||
use core::arch::wasm32::{i64x2_extmul_high_u32x4, i64x2_extmul_low_u32x4, v128};
|
||||
use std::simd::u32x4;
|
||||
|
||||
let [a0, a1, a2, a3]: [v128; 4] = unsafe { transmute(a) };
|
||||
let [b_double0, b_double1, b_double2, b_double3]: [v128; 4] = unsafe { transmute(b_double) };
|
||||
|
||||
let c0_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a0, b_double0)) };
|
||||
let c0_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a0, b_double0)) };
|
||||
let c1_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a1, b_double1)) };
|
||||
let c1_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a1, b_double1)) };
|
||||
let c2_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a2, b_double2)) };
|
||||
let c2_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a2, b_double2)) };
|
||||
let c3_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a3, b_double3)) };
|
||||
let c3_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a3, b_double3)) };
|
||||
|
||||
let (mut c0_even, c0_odd) = c0_lo.deinterleave(c0_hi);
|
||||
let (mut c1_even, c1_odd) = c1_lo.deinterleave(c1_hi);
|
||||
let (mut c2_even, c2_odd) = c2_lo.deinterleave(c2_hi);
|
||||
let (mut c3_even, c3_odd) = c3_lo.deinterleave(c3_hi);
|
||||
c0_even >>= 1;
|
||||
c1_even >>= 1;
|
||||
c2_even >>= 1;
|
||||
c3_even >>= 1;
|
||||
|
||||
let even: PackedM31 = unsafe { transmute([c0_even, c1_even, c2_even, c3_even]) };
|
||||
let odd: PackedM31 = unsafe { transmute([c0_odd, c1_odd, c2_odd, c3_odd]) };
|
||||
|
||||
even + odd
|
||||
}
|
||||
|
||||
/// Returns `a * b`.
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
pub(crate) fn _mul_avx512(a: PackedM31, b: PackedM31) -> PackedM31 {
|
||||
_mul_doubled_avx512(a, b.0 + b.0)
|
||||
}
|
||||
|
||||
/// Returns `a * b`.
|
||||
///
|
||||
/// `b_double` should be in the range `[0, 2P]`.
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
pub(crate) fn _mul_doubled_avx512(a: PackedM31, b_double: u32x16) -> PackedM31 {
|
||||
use std::arch::x86_64::{__m512i, _mm512_mul_epu32, _mm512_srli_epi64};
|
||||
|
||||
let a: __m512i = unsafe { transmute(a) };
|
||||
let b_double: __m512i = unsafe { transmute(b_double) };
|
||||
|
||||
// Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of
|
||||
// the first operand.
|
||||
let a_e = a;
|
||||
// Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of
|
||||
// the first operand.
|
||||
let a_o = unsafe { _mm512_srli_epi64(a, 32) };
|
||||
|
||||
let b_dbl_e = b_double;
|
||||
let b_dbl_o = unsafe { _mm512_srli_epi64(b_double, 32) };
|
||||
|
||||
// To compute prod = a * b start by multiplying a_e/odd by b_dbl_e/odd.
|
||||
let prod_dbl_e: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_e, b_dbl_e)) };
|
||||
let prod_dbl_o: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_o, b_dbl_o)) };
|
||||
|
||||
// The result of a multiplication holds `a * b` as a 64-bit value.
// Each 64-bit word looks like this:
|
||||
// 1 31 31 1
|
||||
// prod_dbl_e - |0|prod_e_h|prod_e_l|0|
|
||||
// prod_dbl_o - |0|prod_o_h|prod_o_l|0|
|
||||
|
||||
// Interleave the even words of prod_dbl_e with the even words of prod_dbl_o:
|
||||
let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o);
|
||||
// prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0|
|
||||
// Divide by 2:
|
||||
prod_lo >>= 1;
|
||||
// prod_lo - |0|prod_o_l|0|prod_e_l|
|
||||
|
||||
// Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o:
|
||||
let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o);
|
||||
// prod_hi - |0|prod_o_h|0|prod_e_h|
|
||||
|
||||
PackedM31(prod_lo) + PackedM31(prod_hi)
|
||||
}
|
||||
|
||||
/// Returns `a * b`.
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
pub(crate) fn _mul_avx2(a: PackedM31, b: PackedM31) -> PackedM31 {
|
||||
_mul_doubled_avx2(a, b.0 + b.0)
|
||||
}
|
||||
|
||||
/// Returns `a * b`.
|
||||
///
|
||||
/// `b_double` should be in the range `[0, 2P]`.
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
pub(crate) fn _mul_doubled_avx2(a: PackedM31, b_double: u32x16) -> PackedM31 {
|
||||
use std::arch::x86_64::{__m256i, _mm256_mul_epu32, _mm256_srli_epi64};
|
||||
|
||||
let [a0, a1]: [__m256i; 2] = unsafe { transmute(a) };
|
||||
let [b0_dbl, b1_dbl]: [__m256i; 2] = unsafe { transmute(b_double) };
|
||||
|
||||
// Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of
|
||||
// the first operand.
|
||||
let a0_e = a0;
|
||||
let a1_e = a1;
|
||||
// Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of
|
||||
// the first operand.
|
||||
let a0_o = unsafe { _mm256_srli_epi64(a0, 32) };
|
||||
let a1_o = unsafe { _mm256_srli_epi64(a1, 32) };
|
||||
|
||||
let b0_dbl_e = b0_dbl;
|
||||
let b1_dbl_e = b1_dbl;
|
||||
let b0_dbl_o = unsafe { _mm256_srli_epi64(b0_dbl, 32) };
|
||||
let b1_dbl_o = unsafe { _mm256_srli_epi64(b1_dbl, 32) };
|
||||
|
||||
// To compute prod = a * b start by multiplying a0/1_e/odd by b0/1_e/odd.
|
||||
let prod0_dbl_e = unsafe { _mm256_mul_epu32(a0_e, b0_dbl_e) };
|
||||
let prod0_dbl_o = unsafe { _mm256_mul_epu32(a0_o, b0_dbl_o) };
|
||||
let prod1_dbl_e = unsafe { _mm256_mul_epu32(a1_e, b1_dbl_e) };
|
||||
let prod1_dbl_o = unsafe { _mm256_mul_epu32(a1_o, b1_dbl_o) };
|
||||
|
||||
let prod_dbl_e: u32x16 = unsafe { transmute([prod0_dbl_e, prod1_dbl_e]) };
|
||||
let prod_dbl_o: u32x16 = unsafe { transmute([prod0_dbl_o, prod1_dbl_o]) };
|
||||
|
||||
// The result of a multiplication holds `a * b` as a 64-bit value.
// Each 64-bit word looks like this:
|
||||
// 1 31 31 1
|
||||
// prod_dbl_e - |0|prod_e_h|prod_e_l|0|
|
||||
// prod_dbl_o - |0|prod_o_h|prod_o_l|0|
|
||||
|
||||
// Interleave the even words of prod_dbl_e with the even words of prod_dbl_o:
|
||||
let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o);
|
||||
// prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0|
|
||||
// Divide by 2:
|
||||
prod_lo >>= 1;
|
||||
// prod_lo - |0|prod_o_l|0|prod_e_l|
|
||||
|
||||
// Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o:
|
||||
let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o);
|
||||
// prod_hi - |0|prod_o_h|0|prod_e_h|
|
||||
|
||||
PackedM31(prod_lo) + PackedM31(prod_hi)
|
||||
}
|
||||
|
||||
/// Returns `a * b`.
|
||||
///
|
||||
/// Should only be used in the absence of a platform-specific implementation.
|
||||
pub(crate) fn _mul_simd(a: PackedM31, b: PackedM31) -> PackedM31 {
|
||||
_mul_doubled_simd(a, b.0 + b.0)
|
||||
}
|
||||
|
||||
/// Returns `a * b`.
|
||||
///
|
||||
/// Should only be used in the absence of a platform-specific implementation.
|
||||
///
|
||||
/// `b_double` should be in the range `[0, 2P]`.
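///
/// Rationale for the doubling (assuming `b_double = 2 * b`): if `a * b = h * 2^31 + l` with
/// `l < 2^31`, then `a * b_double = h * 2^32 + 2 * l`, so each 64-bit product holds `h` in its
/// high 32-bit word and `2 * l` in its low word. Since `2^31 = 1 (mod P)`, the reduced product
/// is `h + l (mod P)`, which is what the final addition below computes.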
|
||||
pub(crate) fn _mul_doubled_simd(a: PackedM31, b_double: u32x16) -> PackedM31 {
|
||||
const MASK_EVENS: Simd<u64, { N_LANES / 2 }> = Simd::from_array([0xFFFFFFFF; { N_LANES / 2 }]);
|
||||
|
||||
// Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of
|
||||
// the first operand.
|
||||
let a_e = unsafe { transmute::<_, Simd<u64, { N_LANES / 2 }>>(a.0) & MASK_EVENS };
|
||||
// Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of
|
||||
// the first operand.
|
||||
let a_o = unsafe { transmute::<_, Simd<u64, { N_LANES / 2 }>>(a) >> 32 };
|
||||
|
||||
let b_dbl_e = unsafe { transmute::<_, Simd<u64, { N_LANES / 2 }>>(b_double) & MASK_EVENS };
|
||||
let b_dbl_o = unsafe { transmute::<_, Simd<u64, { N_LANES / 2 }>>(b_double) >> 32 };
|
||||
|
||||
// To compute prod = a * b start by multiplying
|
||||
// a_e/o by b_dbl_e/o.
|
||||
let prod_e_dbl = a_e * b_dbl_e;
|
||||
let prod_o_dbl = a_o * b_dbl_o;
|
||||
|
||||
// The result of a multiplication holds `a * b` as a 64-bit value.
// Each 64-bit word looks like this:
|
||||
// 1 31 31 1
|
||||
// prod_e_dbl - |0|prod_e_h|prod_e_l|0|
|
||||
// prod_o_dbl - |0|prod_o_h|prod_o_l|0|
|
||||
|
||||
// Interleave the even words of prod_e_dbl with the even words of prod_o_dbl:
|
||||
// let prod_lows = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS,
|
||||
// prod_o_dbl);
|
||||
// prod_ls - |prod_o_l|0|prod_e_l|0|
|
||||
let mut prod_lows = InterleaveEvens::concat_swizzle(
|
||||
unsafe { transmute::<_, Simd<u32, N_LANES>>(prod_e_dbl) },
|
||||
unsafe { transmute::<_, Simd<u32, N_LANES>>(prod_o_dbl) },
|
||||
);
|
||||
// Divide by 2:
|
||||
prod_lows >>= 1;
|
||||
// prod_ls - |0|prod_o_l|0|prod_e_l|
|
||||
|
||||
// Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl:
|
||||
let prod_highs = InterleaveOdds::concat_swizzle(
|
||||
unsafe { transmute::<_, Simd<u32, N_LANES>>(prod_e_dbl) },
|
||||
unsafe { transmute::<_, Simd<u32, N_LANES>>(prod_o_dbl) },
|
||||
);
|
||||
|
||||
// prod_hs - |0|prod_o_h|0|prod_e_h|
|
||||
PackedM31(prod_lows) + PackedM31(prod_highs)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::array;
|
||||
|
||||
use aligned::{Aligned, A64};
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use super::PackedM31;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::FieldExpOps;
|
||||
|
||||
#[test]
|
||||
fn addition_works() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let lhs = rng.gen();
|
||||
let rhs = rng.gen();
|
||||
let packed_lhs = PackedM31::from_array(lhs);
|
||||
let packed_rhs = PackedM31::from_array(rhs);
|
||||
|
||||
let res = packed_lhs + packed_rhs;
|
||||
|
||||
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn subtraction_works() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let lhs = rng.gen();
|
||||
let rhs = rng.gen();
|
||||
let packed_lhs = PackedM31::from_array(lhs);
|
||||
let packed_rhs = PackedM31::from_array(rhs);
|
||||
|
||||
let res = packed_lhs - packed_rhs;
|
||||
|
||||
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiplication_works() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let lhs = rng.gen();
|
||||
let rhs = rng.gen();
|
||||
let packed_lhs = PackedM31::from_array(lhs);
|
||||
let packed_rhs = PackedM31::from_array(rhs);
|
||||
|
||||
let res = packed_lhs * packed_rhs;
|
||||
|
||||
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn negation_works() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let values = rng.gen();
|
||||
let packed_values = PackedM31::from_array(values);
|
||||
|
||||
let res = -packed_values;
|
||||
|
||||
assert_eq!(res.to_array(), array::from_fn(|i| -values[i]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_works() {
|
||||
let v: Aligned<A64, [u32; 16]> = Aligned(array::from_fn(|i| i as u32));
|
||||
|
||||
let res = unsafe { PackedM31::load(v.as_ptr()) };
|
||||
|
||||
assert_eq!(res.to_array().map(|v| v.0), *v);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn store_works() {
|
||||
let v = PackedM31::from_array(array::from_fn(BaseField::from));
|
||||
|
||||
let mut res: Aligned<A64, [u32; 16]> = Aligned([0; 16]);
|
||||
unsafe { v.store(res.as_mut_ptr()) };
|
||||
|
||||
assert_eq!(*res, v.to_array().map(|v| v.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inverse_works() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let values = rng.gen();
|
||||
let packed_values = PackedM31::from_array(values);
|
||||
|
||||
let res = packed_values.inverse();
|
||||
|
||||
assert_eq!(res.to_array(), array::from_fn(|i| values[i].inverse()));
|
||||
}
|
||||
}
|
||||
41
Stwo_wrapper/crates/prover/src/core/backend/simd/mod.rs
Normal file
@ -0,0 +1,41 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{Backend, BackendForChannel};
|
||||
use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleChannel;
|
||||
|
||||
pub mod accumulation;
|
||||
pub mod bit_reverse;
|
||||
pub mod blake2s;
|
||||
pub mod circle;
|
||||
pub mod cm31;
|
||||
pub mod column;
|
||||
pub mod domain;
|
||||
pub mod fft;
|
||||
pub mod fri;
|
||||
mod grind;
|
||||
pub mod lookups;
|
||||
pub mod m31;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub mod poseidon252;
|
||||
pub mod prefix_sum;
|
||||
pub mod qm31;
|
||||
pub mod quotients;
|
||||
mod utils;
|
||||
pub mod very_packed_m31;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub mod poseidon_bls;
|
||||
|
||||
#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct SimdBackend;
|
||||
|
||||
impl Backend for SimdBackend {}
|
||||
impl BackendForChannel<Blake2sMerkleChannel> for SimdBackend {}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl BackendForChannel<Poseidon252MerkleChannel> for SimdBackend {}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl BackendForChannel<PoseidonBLSMerkleChannel> for SimdBackend {}
|
||||
@ -0,0 +1,36 @@
|
||||
use itertools::Itertools;
|
||||
use starknet_ff::FieldElement as FieldElement252;
|
||||
|
||||
use super::SimdBackend;
|
||||
use crate::core::backend::{Col, Column, ColumnOps};
|
||||
use crate::core::fields::m31::BaseField;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::core::vcs::ops::MerkleHasher;
|
||||
use crate::core::vcs::ops::MerkleOps;
|
||||
use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleHasher;
|
||||
|
||||
impl ColumnOps<FieldElement252> for SimdBackend {
|
||||
type Column = Vec<FieldElement252>;
|
||||
|
||||
fn bit_reverse_column(_column: &mut Self::Column) {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl MerkleOps<Poseidon252MerkleHasher> for SimdBackend {
|
||||
// TODO(ShaharS): replace with SIMD implementation.
|
||||
fn commit_on_layer(
|
||||
log_size: u32,
|
||||
prev_layer: Option<&Vec<FieldElement252>>,
|
||||
columns: &[&Col<Self, BaseField>],
|
||||
) -> Vec<FieldElement252> {
|
||||
(0..(1 << log_size))
|
||||
.map(|i| {
|
||||
Poseidon252MerkleHasher::hash_node(
|
||||
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
|
||||
&columns.iter().map(|column| column.at(i)).collect_vec(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,36 @@
|
||||
use itertools::Itertools;
|
||||
use ark_bls12_381::Fr as BlsFr;
|
||||
|
||||
use super::SimdBackend;
|
||||
use crate::core::backend::{Col, Column, ColumnOps};
|
||||
use crate::core::fields::m31::BaseField;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::core::vcs::ops::MerkleHasher;
|
||||
use crate::core::vcs::ops::MerkleOps;
|
||||
use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleHasher;
|
||||
|
||||
impl ColumnOps<BlsFr> for SimdBackend {
|
||||
type Column = Vec<BlsFr>;
|
||||
|
||||
fn bit_reverse_column(_column: &mut Self::Column) {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl MerkleOps<PoseidonBLSMerkleHasher> for SimdBackend {
|
||||
// TODO(ShaharS): replace with SIMD implementation.
|
||||
fn commit_on_layer(
|
||||
log_size: u32,
|
||||
prev_layer: Option<&Vec<BlsFr>>,
|
||||
columns: &[&Col<Self, BaseField>],
|
||||
) -> Vec<BlsFr> {
|
||||
(0..(1 << log_size))
|
||||
.map(|i| {
|
||||
PoseidonBLSMerkleHasher::hash_node(
|
||||
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
|
||||
&columns.iter().map(|column| column.at(i)).collect_vec(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
188
Stwo_wrapper/crates/prover/src/core/backend/simd/prefix_sum.rs
Normal file
@ -0,0 +1,188 @@
|
||||
use std::iter::zip;
|
||||
use std::ops::{AddAssign, Sub};
|
||||
|
||||
use itertools::{izip, Itertools};
|
||||
use num_traits::Zero;
|
||||
|
||||
use crate::core::backend::simd::m31::{PackedBaseField, N_LANES};
|
||||
use crate::core::backend::simd::SimdBackend;
|
||||
use crate::core::backend::{Col, Column};
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::utils::{
|
||||
bit_reverse, circle_domain_order_to_coset_order, coset_order_to_circle_domain_order,
|
||||
};
|
||||
|
||||
/// Performs an inclusive prefix sum on values in `Coset` order when provided
|
||||
/// with evaluations in bit-reversed `CircleDomain` order.
|
||||
///
|
||||
/// Based on parallel Blelloch prefix sum:
|
||||
/// <https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda>
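///
/// Roughly: an up-sweep accumulates partial sums up a binary tree, then a down-sweep pushes
/// prefix totals back down. The extra first/last rounds below handle the reordering between
/// bit-reversed `CircleDomain` order and `Coset` order.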
|
||||
pub fn inclusive_prefix_sum(
|
||||
bit_rev_circle_domain_evals: Col<SimdBackend, BaseField>,
|
||||
) -> Col<SimdBackend, BaseField> {
|
||||
if bit_rev_circle_domain_evals.len() < N_LANES * 4 {
|
||||
return inclusive_prefix_sum_slow(bit_rev_circle_domain_evals);
|
||||
}
|
||||
|
||||
let mut res = bit_rev_circle_domain_evals;
|
||||
let packed_len = res.data.len();
|
||||
let (l_half, r_half) = res.data.split_at_mut(packed_len / 2);
|
||||
|
||||
// Up Sweep
|
||||
// ========
|
||||
// Handle the first two up sweep rounds manually.
|
||||
// Required due to the different ordering of `CircleDomain` and `Coset`.
|
||||
// Evaluations are provided in bit-reversed `CircleDomain` order.
|
||||
for ([l0, l1], [r0, r1]) in izip!(l_half.array_chunks_mut(), r_half.array_chunks_mut().rev()) {
|
||||
let (mut half_coset0_lo, half_coset1_hi_rev) = l0.deinterleave(*l1);
|
||||
let half_coset1_hi = half_coset1_hi_rev.reverse();
|
||||
let (mut half_coset0_hi, half_coset1_lo_rev) = r0.deinterleave(*r1);
|
||||
let half_coset1_lo = half_coset1_lo_rev.reverse();
|
||||
up_sweep_val(&mut half_coset0_lo, half_coset1_lo);
|
||||
up_sweep_val(&mut half_coset0_hi, half_coset1_hi);
|
||||
*l0 = half_coset0_lo;
|
||||
*l1 = half_coset0_hi;
|
||||
*r0 = half_coset1_lo;
|
||||
*r1 = half_coset1_hi;
|
||||
}
|
||||
let half_coset0_sums: &mut [PackedBaseField] = l_half;
|
||||
for i in 0..half_coset0_sums.len() / 2 {
|
||||
let lo_index = i * 2;
|
||||
let hi_index = half_coset0_sums.len() - 1 - i * 2;
|
||||
let hi = half_coset0_sums[hi_index];
|
||||
up_sweep_val(&mut half_coset0_sums[lo_index], hi)
|
||||
}
|
||||
// Handle remaining up sweep rounds.
|
||||
let mut chunk_size = half_coset0_sums.len() / 2;
|
||||
while chunk_size > 1 {
|
||||
let (lows, highs) = half_coset0_sums.split_at_mut(chunk_size);
|
||||
zip(lows.array_chunks_mut(), highs.array_chunks())
|
||||
.for_each(|([lo, _], [hi, _])| up_sweep_val(lo, *hi));
|
||||
chunk_size /= 2;
|
||||
}
|
||||
// Up sweep the last SIMD vector.
|
||||
let mut first_vec = half_coset0_sums.first().unwrap().to_array();
|
||||
let mut chunk_size = first_vec.len() / 2;
|
||||
while chunk_size > 0 {
|
||||
let (lows, highs) = first_vec.split_at_mut(chunk_size);
|
||||
zip(lows, highs).for_each(|(lo, hi)| up_sweep_val(lo, *hi));
|
||||
chunk_size /= 2;
|
||||
}
|
||||
|
||||
// Down Sweep
|
||||
// ==========
|
||||
// Down sweep the last SIMD vector.
|
||||
let mut chunk_size = 1;
|
||||
while chunk_size < first_vec.len() {
|
||||
let (lows, highs) = first_vec.split_at_mut(chunk_size);
|
||||
zip(lows, highs).for_each(|(lo, hi)| down_sweep_val(lo, hi));
|
||||
chunk_size *= 2;
|
||||
}
|
||||
// Re-insert the SIMD vector.
|
||||
*half_coset0_sums.first_mut().unwrap() = first_vec.into();
|
||||
// Handle remaining down sweep rounds (except first two).
|
||||
let mut chunk_size = 2;
|
||||
while chunk_size < half_coset0_sums.len() {
|
||||
let (lows, highs) = half_coset0_sums.split_at_mut(chunk_size);
|
||||
zip(lows.array_chunks_mut(), highs.array_chunks_mut())
|
||||
.for_each(|([lo, _], [hi, _])| down_sweep_val(lo, hi));
|
||||
chunk_size *= 2;
|
||||
}
|
||||
// Handle last two down sweep rounds manually.
|
||||
// Required due to the different ordering of `CircleDomain` and `Coset`.
|
||||
// Evaluations must be returned in bit-reversed `CircleDomain` order.
|
||||
for i in 0..half_coset0_sums.len() / 2 {
|
||||
let lo_index = i * 2;
|
||||
let hi_index = half_coset0_sums.len() - 1 - i * 2;
|
||||
let (mut lo, mut hi) = (half_coset0_sums[lo_index], half_coset0_sums[hi_index]);
|
||||
down_sweep_val(&mut lo, &mut hi);
|
||||
(half_coset0_sums[lo_index], half_coset0_sums[hi_index]) = (lo, hi);
|
||||
}
|
||||
for ([l0, l1], [r0, r1]) in izip!(l_half.array_chunks_mut(), r_half.array_chunks_mut().rev()) {
|
||||
let mut half_coset0_lo = *l0;
|
||||
let mut half_coset1_lo = *r0;
|
||||
down_sweep_val(&mut half_coset0_lo, &mut half_coset1_lo);
|
||||
let mut half_coset0_hi = *l1;
|
||||
let mut half_coset1_hi = *r1;
|
||||
down_sweep_val(&mut half_coset0_hi, &mut half_coset1_hi);
|
||||
(*l0, *l1) = half_coset0_lo.interleave(half_coset1_hi.reverse());
|
||||
(*r0, *r1) = half_coset0_hi.interleave(half_coset1_lo.reverse());
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
fn up_sweep_val<F: AddAssign + Copy>(lo: &mut F, hi: F) {
|
||||
*lo += hi;
|
||||
}
|
||||
|
||||
fn down_sweep_val<F: Sub<Output = F> + Copy>(lo: &mut F, hi: &mut F) {
|
||||
(*lo, *hi) = (*lo - *hi, *lo)
|
||||
}
|
||||
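// Illustrative sketch (not part of the diff): a plain scalar Blelloch-style
// scan, to show the two-phase (up-sweep / down-sweep) structure that
// `inclusive_prefix_sum` applies to SIMD vectors. The SIMD code above uses a
// subtract-based down sweep plus the `CircleDomain`/`Coset` reordering; this
// toy version only shows the tree of partial sums.
fn blelloch_inclusive(a: &mut [u64]) {
    let n = a.len();
    assert!(n.is_power_of_two());
    // Up-sweep: positions 2d-1, 4d-1, ... accumulate the sum of their d-block.
    let mut d = 1;
    while d < n {
        let mut i = 2 * d - 1;
        while i < n {
            a[i] += a[i - d];
            i += 2 * d;
        }
        d *= 2;
    }
    // Down-sweep: propagate block sums into the right-hand sub-blocks.
    let mut d = n / 4;
    while d >= 1 {
        let mut i = 3 * d - 1;
        while i < n {
            a[i] += a[i - d];
            i += 2 * d;
        }
        d /= 2;
    }
}

fn main() {
    let mut v = [3u64, 1, 4, 1, 5, 9, 2, 6];
    blelloch_inclusive(&mut v);
    assert_eq!(v, [3, 4, 8, 9, 14, 23, 25, 31]);
}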
|
||||
fn inclusive_prefix_sum_slow(
|
||||
bit_rev_circle_domain_evals: Col<SimdBackend, BaseField>,
|
||||
) -> Col<SimdBackend, BaseField> {
|
||||
// Obtain values in coset order.
|
||||
let mut coset_order_eval = bit_rev_circle_domain_evals.into_cpu_vec();
|
||||
bit_reverse(&mut coset_order_eval);
|
||||
coset_order_eval = circle_domain_order_to_coset_order(&coset_order_eval);
|
||||
let coset_order_prefix_sum = coset_order_eval
|
||||
.into_iter()
|
||||
.scan(BaseField::zero(), |acc, v| {
|
||||
*acc += v;
|
||||
Some(*acc)
|
||||
})
|
||||
.collect_vec();
|
||||
let mut circle_domain_order_eval = coset_order_to_circle_domain_order(&coset_order_prefix_sum);
|
||||
bit_reverse(&mut circle_domain_order_eval);
|
||||
circle_domain_order_eval.into_iter().collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use test_log::test;
|
||||
|
||||
use super::inclusive_prefix_sum;
|
||||
use crate::core::backend::simd::column::BaseColumn;
|
||||
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum_slow;
|
||||
use crate::core::backend::Column;
|
||||
|
||||
#[test]
|
||||
fn exclusive_prefix_sum_simd_with_log_size_3_works() {
|
||||
const LOG_N: u32 = 3;
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let evals: BaseColumn = (0..1 << LOG_N).map(|_| rng.gen()).collect();
|
||||
let expected = inclusive_prefix_sum_slow(evals.clone());
|
||||
|
||||
let res = inclusive_prefix_sum(evals);
|
||||
|
||||
assert_eq!(res.to_cpu(), expected.to_cpu());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exclusive_prefix_sum_simd_with_log_size_6_works() {
|
||||
const LOG_N: u32 = 6;
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let evals: BaseColumn = (0..1 << LOG_N).map(|_| rng.gen()).collect();
|
||||
let expected = inclusive_prefix_sum_slow(evals.clone());
|
||||
|
||||
let res = inclusive_prefix_sum(evals);
|
||||
|
||||
assert_eq!(res.to_cpu(), expected.to_cpu());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exclusive_prefix_sum_simd_with_log_size_8_works() {
|
||||
const LOG_N: u32 = 8;
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let evals: BaseColumn = (0..1 << LOG_N).map(|_| rng.gen()).collect();
|
||||
let expected = inclusive_prefix_sum_slow(evals.clone());
|
||||
|
||||
let res = inclusive_prefix_sum(evals);
|
||||
|
||||
assert_eq!(res.to_cpu(), expected.to_cpu());
|
||||
}
|
||||
}
|
||||
357
Stwo_wrapper/crates/prover/src/core/backend/simd/qm31.rs
Normal file
@ -0,0 +1,357 @@
|
||||
use std::array;
|
||||
use std::iter::Sum;
|
||||
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
|
||||
|
||||
use bytemuck::{Pod, Zeroable};
|
||||
use num_traits::{One, Zero};
|
||||
use rand::distributions::{Distribution, Standard};
|
||||
|
||||
use super::cm31::PackedCM31;
|
||||
use super::m31::{PackedM31, N_LANES};
|
||||
use crate::core::fields::qm31::QM31;
|
||||
use crate::core::fields::FieldExpOps;
|
||||
|
||||
pub type PackedSecureField = PackedQM31;
|
||||
|
||||
/// SIMD implementation of [`QM31`].
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct PackedQM31(pub [PackedCM31; 2]);
|
||||
|
||||
impl PackedQM31 {
|
||||
/// Constructs a new instance with all vector elements set to `value`.
|
||||
pub fn broadcast(value: QM31) -> Self {
|
||||
Self([
|
||||
PackedCM31::broadcast(value.0),
|
||||
PackedCM31::broadcast(value.1),
|
||||
])
|
||||
}
|
||||
|
||||
/// Returns all `a` values such that each vector element is represented as `a + bu`.
|
||||
pub fn a(&self) -> PackedCM31 {
|
||||
self.0[0]
|
||||
}
|
||||
|
||||
/// Returns all `b` values such that each vector element is represented as `a + bu`.
|
||||
pub fn b(&self) -> PackedCM31 {
|
||||
self.0[1]
|
||||
}
|
||||
|
||||
pub fn to_array(&self) -> [QM31; N_LANES] {
|
||||
let a = self.a().to_array();
|
||||
let b = self.b().to_array();
|
||||
array::from_fn(|i| QM31(a[i], b[i]))
|
||||
}
|
||||
|
||||
pub fn from_array(values: [QM31; N_LANES]) -> Self {
|
||||
let a = values.map(|v| v.0);
|
||||
let b = values.map(|v| v.1);
|
||||
Self([PackedCM31::from_array(a), PackedCM31::from_array(b)])
|
||||
}
|
||||
|
||||
/// Interleaves two vectors.
|
||||
pub fn interleave(self, other: Self) -> (Self, Self) {
|
||||
let Self([a_evens, b_evens]) = self;
|
||||
let Self([a_odds, b_odds]) = other;
|
||||
let (a_lhs, a_rhs) = a_evens.interleave(a_odds);
|
||||
let (b_lhs, b_rhs) = b_evens.interleave(b_odds);
|
||||
(Self([a_lhs, b_lhs]), Self([a_rhs, b_rhs]))
|
||||
}
|
||||
|
||||
/// Deinterleaves two vectors.
|
||||
pub fn deinterleave(self, other: Self) -> (Self, Self) {
|
||||
let Self([a_lhs, b_lhs]) = self;
|
||||
let Self([a_rhs, b_rhs]) = other;
|
||||
let (a_evens, a_odds) = a_lhs.deinterleave(a_rhs);
|
||||
let (b_evens, b_odds) = b_lhs.deinterleave(b_rhs);
|
||||
(Self([a_evens, b_evens]), Self([a_odds, b_odds]))
|
||||
}
|
||||
|
||||
/// Sums all the elements in the vector.
|
||||
pub fn pointwise_sum(self) -> QM31 {
|
||||
self.to_array().into_iter().sum()
|
||||
}
|
||||
|
||||
/// Doubles each element in the vector.
|
||||
pub fn double(self) -> Self {
|
||||
let Self([a, b]) = self;
|
||||
Self([a.double(), b.double()])
|
||||
}
|
||||
|
||||
/// Returns vectors `a, b, c, d` such that element `i` is represented as
|
||||
/// `QM31(a_i, b_i, c_i, d_i)`.
|
||||
pub fn into_packed_m31s(self) -> [PackedM31; 4] {
|
||||
let Self([PackedCM31([a, b]), PackedCM31([c, d])]) = self;
|
||||
[a, b, c, d]
|
||||
}
|
||||
|
||||
/// Creates an instance from vectors `a, b, c, d` such that element `i`
|
||||
/// is represented as `QM31(a_i, b_i, c_i, d_i)`.
|
||||
pub fn from_packed_m31s([a, b, c, d]: [PackedM31; 4]) -> Self {
|
||||
Self([PackedCM31([a, b]), PackedCM31([c, d])])
|
||||
}
|
||||
}
|
||||
|
||||
impl Add for PackedQM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
Self([self.a() + rhs.a(), self.b() + rhs.b()])
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub for PackedQM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
Self([self.a() - rhs.a(), self.b() - rhs.b()])
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul for PackedQM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
// Compute using Karatsuba.
|
||||
// (a + ub) * (c + ud) =
|
||||
// (ac + (2+i)bd) + (ad + bc)u =
|
||||
// ac + 2bd + ibd + (ad + bc)u.
|
||||
let ac = self.a() * rhs.a();
|
||||
let bd = self.b() * rhs.b();
|
||||
let bd_times_1_plus_i = PackedCM31([bd.a() - bd.b(), bd.a() + bd.b()]);
|
||||
// Computes ac + bd.
|
||||
let ac_p_bd = ac + bd;
|
||||
// Computes ad + bc.
|
||||
let ad_p_bc = (self.a() + self.b()) * (rhs.a() + rhs.b()) - ac_p_bd;
|
||||
// ac + 2bd + ibd =
|
||||
// ac + bd + bd + ibd
|
||||
let l = PackedCM31([
|
||||
ac_p_bd.a() + bd_times_1_plus_i.a(),
|
||||
ac_p_bd.b() + bd_times_1_plus_i.b(),
|
||||
]);
|
||||
Self([l, ad_p_bc])
|
||||
}
|
||||
}
|
||||
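// Illustrative sketch (not part of the diff): the Karatsuba trick used in
// `Mul for PackedQM31`, checked on plain complex numbers. `C` stands in for a
// CM31 value (i^2 = -1); the extension generator satisfies u^2 = 2 + i. The
// point is that `ad + bc` is recovered from a single extra product.
#[derive(Copy, Clone, Debug, PartialEq)]
struct C(i64, i64); // a + b*i

impl C {
    fn add(self, o: C) -> C { C(self.0 + o.0, self.1 + o.1) }
    fn sub(self, o: C) -> C { C(self.0 - o.0, self.1 - o.1) }
    fn mul(self, o: C) -> C {
        C(self.0 * o.0 - self.1 * o.1, self.0 * o.1 + self.1 * o.0)
    }
}

fn main() {
    let (a, b) = (C(3, 1), C(2, 5)); // lhs = a + u*b
    let (c, d) = (C(7, 4), C(1, 6)); // rhs = c + u*d
    let two_plus_i = C(2, 1);

    // Schoolbook: (a + ub)(c + ud) = (ac + (2+i)bd) + (ad + bc)u.
    let direct = (
        a.mul(c).add(two_plus_i.mul(b.mul(d))),
        a.mul(d).add(b.mul(c)),
    );

    // Karatsuba: reuse ac and bd, derive ad + bc from one extra product.
    let ac = a.mul(c);
    let bd = b.mul(d);
    let ad_p_bc = a.add(b).mul(c.add(d)).sub(ac.add(bd));
    let karatsuba = (ac.add(two_plus_i.mul(bd)), ad_p_bc);

    assert_eq!(direct, karatsuba); // 3 complex products instead of 4.
}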
|
||||
impl Zero for PackedQM31 {
|
||||
fn zero() -> Self {
|
||||
Self([PackedCM31::zero(), PackedCM31::zero()])
|
||||
}
|
||||
|
||||
fn is_zero(&self) -> bool {
|
||||
self.a().is_zero() && self.b().is_zero()
|
||||
}
|
||||
}
|
||||
|
||||
impl One for PackedQM31 {
|
||||
fn one() -> Self {
|
||||
Self([PackedCM31::one(), PackedCM31::zero()])
|
||||
}
|
||||
}
|
||||
|
||||
impl AddAssign for PackedQM31 {
|
||||
fn add_assign(&mut self, rhs: Self) {
|
||||
*self = *self + rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl MulAssign for PackedQM31 {
|
||||
fn mul_assign(&mut self, rhs: Self) {
|
||||
*self = *self * rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl FieldExpOps for PackedQM31 {
|
||||
fn inverse(&self) -> Self {
|
||||
assert!(!self.is_zero(), "0 has no inverse");
|
||||
// (a + bu)^-1 = (a - bu) / (a^2 - (2+i)b^2).
|
||||
let b2 = self.b().square();
|
||||
let ib2 = PackedCM31([-b2.b(), b2.a()]);
|
||||
let denom = self.a().square() - (b2 + b2 + ib2);
|
||||
let denom_inverse = denom.inverse();
|
||||
Self([self.a() * denom_inverse, -self.b() * denom_inverse])
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<PackedM31> for PackedQM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: PackedM31) -> Self::Output {
|
||||
Self([self.a() + rhs, self.b()])
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<PackedM31> for PackedQM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, rhs: PackedM31) -> Self::Output {
|
||||
let Self([a, b]) = self;
|
||||
Self([a * rhs, b * rhs])
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<PackedCM31> for PackedQM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, rhs: PackedCM31) -> Self::Output {
|
||||
let Self([a, b]) = self;
|
||||
Self([a * rhs, b * rhs])
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<PackedM31> for PackedQM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: PackedM31) -> Self::Output {
|
||||
let Self([a, b]) = self;
|
||||
Self([a - rhs, b])
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<QM31> for PackedQM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: QM31) -> Self::Output {
|
||||
self + PackedQM31::broadcast(rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<QM31> for PackedQM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: QM31) -> Self::Output {
|
||||
self - PackedQM31::broadcast(rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<QM31> for PackedQM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, rhs: QM31) -> Self::Output {
|
||||
self * PackedQM31::broadcast(rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl SubAssign for PackedQM31 {
|
||||
fn sub_assign(&mut self, rhs: Self) {
|
||||
*self = *self - rhs;
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Pod for PackedQM31 {}
|
||||
|
||||
unsafe impl Zeroable for PackedQM31 {
|
||||
fn zeroed() -> Self {
|
||||
unsafe { core::mem::zeroed() }
|
||||
}
|
||||
}
|
||||
|
||||
impl Sum for PackedQM31 {
|
||||
fn sum<I>(mut iter: I) -> Self
|
||||
where
|
||||
I: Iterator<Item = Self>,
|
||||
{
|
||||
let first = iter.next().unwrap_or_else(Self::zero);
|
||||
iter.fold(first, |a, b| a + b)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Sum<&'a Self> for PackedQM31 {
|
||||
fn sum<I>(iter: I) -> Self
|
||||
where
|
||||
I: Iterator<Item = &'a Self>,
|
||||
{
|
||||
iter.copied().sum()
|
||||
}
|
||||
}
|
||||
|
||||
impl Neg for PackedQM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
let Self([a, b]) = self;
|
||||
Self([-a, -b])
|
||||
}
|
||||
}
|
||||
|
||||
impl Distribution<PackedQM31> for Standard {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> PackedQM31 {
|
||||
PackedQM31::from_array(rng.gen())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PackedM31> for PackedQM31 {
|
||||
fn from(value: PackedM31) -> Self {
|
||||
PackedQM31::from_packed_m31s([
|
||||
value,
|
||||
PackedM31::zero(),
|
||||
PackedM31::zero(),
|
||||
PackedM31::zero(),
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<QM31> for PackedQM31 {
|
||||
fn from(value: QM31) -> Self {
|
||||
PackedQM31::broadcast(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::array;
|
||||
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use crate::core::backend::simd::qm31::PackedQM31;
|
||||
|
||||
#[test]
|
||||
fn addition_works() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let lhs = rng.gen();
|
||||
let rhs = rng.gen();
|
||||
let packed_lhs = PackedQM31::from_array(lhs);
|
||||
let packed_rhs = PackedQM31::from_array(rhs);
|
||||
|
||||
let res = packed_lhs + packed_rhs;
|
||||
|
||||
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn subtraction_works() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let lhs = rng.gen();
|
||||
let rhs = rng.gen();
|
||||
let packed_lhs = PackedQM31::from_array(lhs);
|
||||
let packed_rhs = PackedQM31::from_array(rhs);
|
||||
|
||||
let res = packed_lhs - packed_rhs;
|
||||
|
||||
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiplication_works() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let lhs = rng.gen();
|
||||
let rhs = rng.gen();
|
||||
let packed_lhs = PackedQM31::from_array(lhs);
|
||||
let packed_rhs = PackedQM31::from_array(rhs);
|
||||
|
||||
let res = packed_lhs * packed_rhs;
|
||||
|
||||
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn negation_works() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let values = rng.gen();
|
||||
let packed_values = PackedQM31::from_array(values);
|
||||
|
||||
let res = -packed_values;
|
||||
|
||||
assert_eq!(res.to_array(), values.map(|v| -v));
|
||||
}
|
||||
}
|
||||
314
Stwo_wrapper/crates/prover/src/core/backend/simd/quotients.rs
Normal file
@ -0,0 +1,314 @@
|
||||
use itertools::{izip, zip_eq, Itertools};
|
||||
use num_traits::Zero;
|
||||
use tracing::{span, Level};
|
||||
|
||||
use super::cm31::PackedCM31;
|
||||
use super::column::CM31Column;
|
||||
use super::domain::CircleDomainBitRevIterator;
|
||||
use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
|
||||
use super::qm31::PackedSecureField;
|
||||
use super::SimdBackend;
|
||||
use crate::core::backend::cpu::quotients::{batch_random_coeffs, column_line_coeffs};
|
||||
use crate::core::backend::Column;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE};
|
||||
use crate::core::fields::FieldExpOps;
|
||||
use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps};
|
||||
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, PolyOps, SecureEvaluation};
|
||||
use crate::core::poly::BitReversedOrder;
|
||||
use crate::core::utils::bit_reverse;
|
||||
|
||||
pub struct QuotientConstants {
|
||||
pub line_coeffs: Vec<Vec<(SecureField, SecureField, SecureField)>>,
|
||||
pub batch_random_coeffs: Vec<SecureField>,
|
||||
pub denominator_inverses: Vec<CM31Column>,
|
||||
}
|
||||
|
||||
impl QuotientOps for SimdBackend {
|
||||
fn accumulate_quotients(
|
||||
domain: CircleDomain,
|
||||
columns: &[&CircleEvaluation<Self, BaseField, BitReversedOrder>],
|
||||
random_coeff: SecureField,
|
||||
sample_batches: &[ColumnSampleBatch],
|
||||
log_blowup_factor: u32,
|
||||
) -> SecureEvaluation<Self, BitReversedOrder> {
|
||||
// Split the domain into a subdomain and a shift coset.
|
||||
// TODO(spapini): Move to the caller when Columns support slices.
|
||||
let (subdomain, mut subdomain_shifts) = domain.split(log_blowup_factor);
|
||||
|
||||
// Bit reverse the shifts.
|
||||
// Since we traverse the domain in bit-reversed order, we need to bit-reverse the shifts.
|
||||
// To see why, consider the index of a point in the natural order of the domain
|
||||
// (least to most):
|
||||
// b0 b1 b2 b3 b4 b5
|
||||
// b0 adds G, b1 adds 2G, etc.. (b5 is special and flips the sign of the point).
|
||||
// Splitting the domain to 4 parts yields:
|
||||
// subdomain: b2 b3 b4 b5, shifts: b0 b1.
|
||||
// b2 b3 b4 b5 is indeed a circle domain, with a bigger jump.
|
||||
// Traversing the domain in bit-reversed order, after we finish with b5, b4, b3, b2,
|
||||
// we need to change b1 and then b0. This is the bit reverse of the shift b0 b1.
|
||||
bit_reverse(&mut subdomain_shifts);
|
||||
|
||||
let (span, mut extended_eval, subeval_polys) = accumulate_quotients_on_subdomain(
|
||||
subdomain,
|
||||
sample_batches,
|
||||
random_coeff,
|
||||
columns,
|
||||
domain,
|
||||
);
|
||||
|
||||
// Extend the evaluation to the full domain.
|
||||
// TODO(spapini): Try to optimize out all these copies.
|
||||
for (ci, &c) in subdomain_shifts.iter().enumerate() {
|
||||
let subdomain = subdomain.shift(c);
|
||||
|
||||
let twiddles = SimdBackend::precompute_twiddles(subdomain.half_coset);
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for i in 0..SECURE_EXTENSION_DEGREE {
|
||||
// Sanity check.
|
||||
let eval = subeval_polys[i].evaluate_with_twiddles(subdomain, &twiddles);
|
||||
extended_eval.columns[i].data[(ci * eval.data.len())..((ci + 1) * eval.data.len())]
|
||||
.copy_from_slice(&eval.data);
|
||||
}
|
||||
}
|
||||
span.exit();
|
||||
|
||||
SecureEvaluation::new(domain, extended_eval)
|
||||
}
|
||||
}
|
||||
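// Illustrative sketch (not part of the diff): the bit-reversal applied to
// `subdomain_shifts` above. With a blowup of 2 bits there are 4 shifts,
// indexed by the low bits (b0, b1) of the point index; walking the full
// domain in bit-reversed order consumes them as 0, 2, 1, 3, i.e. the
// bit-reversed permutation of 0, 1, 2, 3.
fn bit_reverse_index(i: usize, log_n: u32) -> usize {
    i.reverse_bits() >> (usize::BITS - log_n)
}

fn main() {
    let order: Vec<usize> = (0..4).map(|i| bit_reverse_index(i, 2)).collect();
    assert_eq!(order, vec![0, 2, 1, 3]);
}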
|
||||
fn accumulate_quotients_on_subdomain(
|
||||
subdomain: CircleDomain,
|
||||
sample_batches: &[ColumnSampleBatch],
|
||||
random_coeff: SecureField,
|
||||
columns: &[&CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>],
|
||||
domain: CircleDomain,
|
||||
) -> (
|
||||
span::EnteredSpan,
|
||||
SecureColumnByCoords<SimdBackend>,
|
||||
[crate::core::poly::circle::CirclePoly<SimdBackend>; 4],
|
||||
) {
|
||||
assert!(subdomain.log_size() >= LOG_N_LANES + 2);
|
||||
let mut values =
|
||||
unsafe { SecureColumnByCoords::<SimdBackend>::uninitialized(subdomain.size()) };
|
||||
let quotient_constants = quotient_constants(sample_batches, random_coeff, subdomain);
|
||||
|
||||
let span = span!(Level::INFO, "Quotient accumulation").entered();
|
||||
for (quad_row, points) in CircleDomainBitRevIterator::new(subdomain)
|
||||
.array_chunks::<4>()
|
||||
.enumerate()
|
||||
{
|
||||
// TODO(spapini): Use optimized domain iteration.
|
||||
let (y01, _) = points[0].y.deinterleave(points[1].y);
|
||||
let (y23, _) = points[2].y.deinterleave(points[3].y);
|
||||
let (spaced_ys, _) = y01.deinterleave(y23);
|
||||
let row_accumulator = accumulate_row_quotients(
|
||||
sample_batches,
|
||||
columns,
|
||||
"ient_constants,
|
||||
quad_row,
|
||||
spaced_ys,
|
||||
);
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for i in 0..4 {
|
||||
unsafe { values.set_packed((quad_row << 2) + i, row_accumulator[i]) };
|
||||
}
|
||||
}
|
||||
span.exit();
|
||||
let span = span!(Level::INFO, "Quotient extension").entered();
|
||||
|
||||
// Extend the evaluation to the full domain.
|
||||
let extended_eval =
|
||||
unsafe { SecureColumnByCoords::<SimdBackend>::uninitialized(domain.size()) };
|
||||
|
||||
let mut i = 0;
|
||||
let values = values.columns;
|
||||
let twiddles = SimdBackend::precompute_twiddles(subdomain.half_coset);
|
||||
let subeval_polys = values.map(|c| {
|
||||
i += 1;
|
||||
CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(subdomain, c)
|
||||
.interpolate_with_twiddles(&twiddles)
|
||||
});
|
||||
(span, extended_eval, subeval_polys)
|
||||
}
|
||||
|
||||
/// Accumulates the quotients for 4 * N_LANES rows at a time.
|
||||
/// spaced_ys - y values for N_LANES points in the domain, in jumps of 4.
|
||||
pub fn accumulate_row_quotients(
|
||||
sample_batches: &[ColumnSampleBatch],
|
||||
columns: &[&CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>],
|
||||
quotient_constants: &QuotientConstants,
|
||||
quad_row: usize,
|
||||
spaced_ys: PackedBaseField,
|
||||
) -> [PackedSecureField; 4] {
|
||||
let mut row_accumulator = [PackedSecureField::zero(); 4];
|
||||
for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!(
|
||||
sample_batches,
|
||||
"ient_constants.line_coeffs,
|
||||
"ient_constants.batch_random_coeffs,
|
||||
"ient_constants.denominator_inverses
|
||||
) {
|
||||
let mut numerator = [PackedSecureField::zero(); 4];
|
||||
for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs)
|
||||
{
|
||||
let column = &columns[*column_index];
|
||||
let cvalues: [_; 4] = std::array::from_fn(|i| {
|
||||
PackedSecureField::broadcast(*c) * column.data[(quad_row << 2) + i]
|
||||
});
|
||||
|
||||
// The numerator is the line equation:
|
||||
// c * value - a * point.y - b;
|
||||
// Note that a, b, c were already multiplied by random_coeff^i.
|
||||
// See [column_line_coeffs()] for more details.
|
||||
// This is why we only add here.
|
||||
// 4 consecutive points in the domain in bit-reversed order are:
|
||||
// P, -P, P + H, -P + H.
|
||||
// H being the half point (-1,0). The y values for these are
|
||||
// P.y, -P.y, -P.y, P.y.
|
||||
// We use this fact to save multiplications.
|
||||
// spaced_ys are the y value in jumps of 4:
|
||||
// P0.y, P1.y, P2.y, ...
|
||||
let spaced_ay = PackedSecureField::broadcast(*a) * spaced_ys;
|
||||
// t0:t1 = a*P0.y, -a*P0.y, a*P1.y, -a*P1.y, ...
|
||||
let (t0, t1) = spaced_ay.interleave(-spaced_ay);
|
||||
// t2:t3:t4:t5 = a*P0.y, -a*P0.y, -a*P0.y, a*P0.y, a*P1.y, -a*P1.y, ...
|
||||
let (t2, t3) = t0.interleave(-t0);
|
||||
let (t4, t5) = t1.interleave(-t1);
|
||||
let ay = [t2, t3, t4, t5];
|
||||
for i in 0..4 {
|
||||
numerator[i] += cvalues[i] - ay[i] - PackedSecureField::broadcast(*b);
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..4 {
|
||||
row_accumulator[i] = row_accumulator[i] * PackedSecureField::broadcast(*batch_coeff)
|
||||
+ numerator[i] * denominator_inverses.data[(quad_row << 2) + i];
|
||||
}
|
||||
}
|
||||
row_accumulator
|
||||
}
|
||||
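// Illustrative sketch (not part of the diff): a scalar model of the interleave
// trick in `accumulate_row_quotients`. Starting from the "spaced" values
// a*P0.y, a*P1.y, ..., two interleaves with negations produce the per-row
// pattern a*Pk.y, -a*Pk.y, -a*Pk.y, a*Pk.y without extra multiplications.
fn interleave(a: &[i64], b: &[i64]) -> Vec<i64> {
    a.iter().zip(b).flat_map(|(&x, &y)| [x, y]).collect()
}

fn main() {
    let spaced_ay = vec![10i64, 20, 30]; // a*P0.y, a*P1.y, a*P2.y
    let neg: Vec<i64> = spaced_ay.iter().map(|x| -x).collect();
    let t = interleave(&spaced_ay, &neg); // y0, -y0, y1, -y1, ...
    let t_neg: Vec<i64> = t.iter().map(|x| -x).collect();
    let ay = interleave(&t, &t_neg); // y0, -y0, -y0, y0, y1, -y1, ...
    assert_eq!(ay, vec![10, -10, -10, 10, 20, -20, -20, 20, 30, -30, -30, 30]);
}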
|
||||
fn denominator_inverses(
|
||||
sample_batches: &[ColumnSampleBatch],
|
||||
domain: CircleDomain,
|
||||
) -> Vec<CM31Column> {
|
||||
// We want a P to be on a line that passes through a point Pr + uPi in QM31^2, and its conjugate
|
||||
// Pr - uPi. Thus, Pr - P is parallel to Pi. Or, (Pr - P).x * Pi.y - (Pr - P).y * Pi.x = 0.
|
||||
let flat_denominators: CM31Column = sample_batches
|
||||
.iter()
|
||||
.flat_map(|sample_batch| {
|
||||
// Extract Pr, Pi.
|
||||
let prx = PackedCM31::broadcast(sample_batch.point.x.0);
|
||||
let pry = PackedCM31::broadcast(sample_batch.point.y.0);
|
||||
let pix = PackedCM31::broadcast(sample_batch.point.x.1);
|
||||
let piy = PackedCM31::broadcast(sample_batch.point.y.1);
|
||||
|
||||
// Line equation through pr +-u pi.
|
||||
// (p-pr)*
|
||||
CircleDomainBitRevIterator::new(domain)
|
||||
.map(|points| (prx - points.x) * piy - (pry - points.y) * pix)
|
||||
.collect_vec()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut flat_denominator_inverses =
|
||||
unsafe { CM31Column::uninitialized(flat_denominators.len()) };
|
||||
FieldExpOps::batch_inverse(
|
||||
&flat_denominators.data,
|
||||
&mut flat_denominator_inverses.data[..],
|
||||
);
|
||||
|
||||
flat_denominator_inverses
|
||||
.data
|
||||
.chunks(domain.size() / N_LANES)
|
||||
.map(|denominator_inverses| denominator_inverses.iter().copied().collect())
|
||||
.collect()
|
||||
}
|
||||
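// Illustrative sketch (not part of the diff): the vanishing condition computed
// in `denominator_inverses`. For P = Pr + t*Pi the difference Pr - P is a
// multiple of Pi, so the 2x2 determinant below is zero; for points off the
// line it is not. Plain integers stand in for the CM31 coordinates.
fn det(prx: i64, pry: i64, pix: i64, piy: i64, px: i64, py: i64) -> i64 {
    (prx - px) * piy - (pry - py) * pix
}

fn main() {
    let (prx, pry) = (3, 4); // "real" part Pr
    let (pix, piy) = (1, 2); // "imaginary" part Pi
    let on_line = (prx + 5 * pix, pry + 5 * piy); // Pr + 5*Pi
    let off_line = (prx + 5 * pix + 1, pry + 5 * piy);
    assert_eq!(det(prx, pry, pix, piy, on_line.0, on_line.1), 0);
    assert_ne!(det(prx, pry, pix, piy, off_line.0, off_line.1), 0);
}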
|
||||
fn quotient_constants(
|
||||
sample_batches: &[ColumnSampleBatch],
|
||||
random_coeff: SecureField,
|
||||
domain: CircleDomain,
|
||||
) -> QuotientConstants {
|
||||
let _span = span!(Level::INFO, "Quotient constants").entered();
|
||||
let line_coeffs = column_line_coeffs(sample_batches, random_coeff);
|
||||
let batch_random_coeffs = batch_random_coeffs(sample_batches, random_coeff);
|
||||
let denominator_inverses = denominator_inverses(sample_batches, domain);
|
||||
QuotientConstants {
|
||||
line_coeffs,
|
||||
batch_random_coeffs,
|
||||
denominator_inverses,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::core::backend::simd::column::BaseColumn;
|
||||
use crate::core::backend::simd::SimdBackend;
|
||||
use crate::core::backend::{Column, CpuBackend};
|
||||
use crate::core::circle::SECURE_FIELD_CIRCLE_GEN;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps};
|
||||
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
|
||||
use crate::core::poly::BitReversedOrder;
|
||||
use crate::qm31;
|
||||
|
||||
#[test]
|
||||
fn test_accumulate_quotients() {
|
||||
const LOG_SIZE: u32 = 8;
|
||||
const LOG_BLOWUP_FACTOR: u32 = 1;
|
||||
let small_domain = CanonicCoset::new(LOG_SIZE).circle_domain();
|
||||
let domain = CanonicCoset::new(LOG_SIZE + LOG_BLOWUP_FACTOR).circle_domain();
|
||||
let e0: BaseColumn = (0..small_domain.size()).map(BaseField::from).collect();
|
||||
let e1: BaseColumn = (0..small_domain.size())
|
||||
.map(|i| BaseField::from(2 * i))
|
||||
.collect();
|
||||
let polys = vec![
|
||||
CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(small_domain, e0)
|
||||
.interpolate(),
|
||||
CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(small_domain, e1)
|
||||
.interpolate(),
|
||||
];
|
||||
let columns = vec![polys[0].evaluate(domain), polys[1].evaluate(domain)];
|
||||
let random_coeff = qm31!(1, 2, 3, 4);
|
||||
let a = polys[0].eval_at_point(SECURE_FIELD_CIRCLE_GEN);
|
||||
let b = polys[1].eval_at_point(SECURE_FIELD_CIRCLE_GEN);
|
||||
let samples = vec![ColumnSampleBatch {
|
||||
point: SECURE_FIELD_CIRCLE_GEN,
|
||||
columns_and_values: vec![(0, a), (1, b)],
|
||||
}];
|
||||
let cpu_columns = columns
|
||||
.iter()
|
||||
.map(|c| {
|
||||
CircleEvaluation::<CpuBackend, _, BitReversedOrder>::new(
|
||||
c.domain,
|
||||
c.values.to_cpu(),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let cpu_result = CpuBackend::accumulate_quotients(
|
||||
domain,
|
||||
&cpu_columns.iter().collect_vec(),
|
||||
random_coeff,
|
||||
&samples,
|
||||
LOG_BLOWUP_FACTOR,
|
||||
)
|
||||
.values
|
||||
.to_vec();
|
||||
|
||||
let res = SimdBackend::accumulate_quotients(
|
||||
domain,
|
||||
&columns.iter().collect_vec(),
|
||||
random_coeff,
|
||||
&samples,
|
||||
LOG_BLOWUP_FACTOR,
|
||||
)
|
||||
.values
|
||||
.to_vec();
|
||||
|
||||
assert_eq!(res, cpu_result);
|
||||
}
|
||||
}
|
||||
52
Stwo_wrapper/crates/prover/src/core/backend/simd/utils.rs
Normal file
@ -0,0 +1,52 @@
|
||||
use std::simd::Swizzle;
|
||||
|
||||
/// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors.
|
||||
pub struct InterleaveEvens;
|
||||
|
||||
impl<const N: usize> Swizzle<N> for InterleaveEvens {
|
||||
const INDEX: [usize; N] = parity_interleave(false);
|
||||
}
|
||||
|
||||
/// Used with [`Swizzle::concat_swizzle`] to interleave the odd values of two vectors.
|
||||
pub struct InterleaveOdds;
|
||||
|
||||
impl<const N: usize> Swizzle<N> for InterleaveOdds {
|
||||
const INDEX: [usize; N] = parity_interleave(true);
|
||||
}
|
||||
|
||||
const fn parity_interleave<const N: usize>(odd: bool) -> [usize; N] {
|
||||
let mut res = [0; N];
|
||||
let mut i = 0;
|
||||
while i < N {
|
||||
res[i] = (i % 2) * N + (i / 2) * 2 + if odd { 1 } else { 0 };
|
||||
i += 1;
|
||||
}
|
||||
res
|
||||
}
|
||||
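// Illustrative sketch (not part of the diff): the index pattern produced by
// `parity_interleave` for N = 4. In `concat_swizzle`, an index j < N selects
// from the first input and j >= N from the second, so these patterns pick the
// even (resp. odd) lanes of both vectors and interleave them.
fn main() {
    let evens: [usize; 4] = core::array::from_fn(|i| (i % 2) * 4 + (i / 2) * 2);
    let odds: [usize; 4] = core::array::from_fn(|i| (i % 2) * 4 + (i / 2) * 2 + 1);
    assert_eq!(evens, [0, 4, 2, 6]);
    assert_eq!(odds, [1, 5, 3, 7]);
}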
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::simd::{u32x4, Swizzle};
|
||||
|
||||
use super::{InterleaveEvens, InterleaveOdds};
|
||||
|
||||
#[test]
|
||||
fn interleave_evens() {
|
||||
let lo = u32x4::from_array([0, 1, 2, 3]);
|
||||
let hi = u32x4::from_array([4, 5, 6, 7]);
|
||||
|
||||
let res = InterleaveEvens::concat_swizzle(lo, hi);
|
||||
|
||||
assert_eq!(res, u32x4::from_array([0, 4, 2, 6]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interleave_odds() {
|
||||
let lo = u32x4::from_array([0, 1, 2, 3]);
|
||||
let hi = u32x4::from_array([4, 5, 6, 7]);
|
||||
|
||||
let res = InterleaveOdds::concat_swizzle(lo, hi);
|
||||
|
||||
assert_eq!(res, u32x4::from_array([1, 5, 3, 7]));
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,222 @@
|
||||
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};
|
||||
|
||||
use bytemuck::{Pod, Zeroable};
|
||||
use num_traits::{One, Zero};
|
||||
|
||||
use super::cm31::PackedCM31;
|
||||
use super::m31::{PackedM31, N_LANES};
|
||||
use super::qm31::PackedQM31;
|
||||
use crate::core::fields::cm31::CM31;
|
||||
use crate::core::fields::m31::{pow2147483645, M31};
|
||||
use crate::core::fields::qm31::QM31;
|
||||
use crate::core::fields::FieldExpOps;
|
||||
|
||||
pub const LOG_N_VERY_PACKED_ELEMS: u32 = 1;
|
||||
pub const N_VERY_PACKED_ELEMS: usize = 1 << LOG_N_VERY_PACKED_ELEMS;
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
#[repr(transparent)]
|
||||
pub struct Vectorized<A, const N: usize>(pub [A; N]);
|
||||
|
||||
impl<A, const N: usize> Vectorized<A, N> {
|
||||
pub fn from_fn<F>(cb: F) -> Self
|
||||
where
|
||||
F: FnMut(usize) -> A,
|
||||
{
|
||||
Vectorized(std::array::from_fn(cb))
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<A, const N: usize> Zeroable for Vectorized<A, N> {
|
||||
fn zeroed() -> Self {
|
||||
unsafe { core::mem::zeroed() }
|
||||
}
|
||||
}
|
||||
unsafe impl<A: Pod, const N: usize> Pod for Vectorized<A, N> {}
|
||||
|
||||
pub type VeryPackedM31 = Vectorized<PackedM31, N_VERY_PACKED_ELEMS>;
|
||||
pub type VeryPackedCM31 = Vectorized<PackedCM31, N_VERY_PACKED_ELEMS>;
|
||||
pub type VeryPackedQM31 = Vectorized<PackedQM31, N_VERY_PACKED_ELEMS>;
|
||||
pub type VeryPackedBaseField = VeryPackedM31;
|
||||
pub type VeryPackedSecureField = VeryPackedQM31;
|
||||
|
||||
impl VeryPackedM31 {
|
||||
pub fn broadcast(value: M31) -> Self {
|
||||
Self::from_fn(|_| PackedM31::broadcast(value))
|
||||
}
|
||||
|
||||
pub fn from_array(values: [M31; N_LANES * N_VERY_PACKED_ELEMS]) -> VeryPackedM31 {
|
||||
Self::from_fn(|i| {
|
||||
let start = i * N_LANES;
|
||||
let end = start + N_LANES;
|
||||
PackedM31::from_array(values[start..end].try_into().unwrap())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn to_array(&self) -> [M31; N_LANES * N_VERY_PACKED_ELEMS] {
|
||||
// Safety: we read `[PackedM31; N_VERY_PACKED_ELEMS]` as N_LANES *
|
||||
// N_VERY_PACKED_ELEMS consecutive `M31` values; each `PackedM31` stores
|
||||
// `[M31; N_LANES]` and the memory layout is contiguous.
|
||||
unsafe {
|
||||
std::slice::from_raw_parts(self.0.as_ptr() as *const M31, N_LANES * N_VERY_PACKED_ELEMS)
|
||||
.try_into()
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VeryPackedCM31 {
|
||||
pub fn broadcast(value: CM31) -> Self {
|
||||
Self::from_fn(|_| PackedCM31::broadcast(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl VeryPackedQM31 {
|
||||
pub fn broadcast(value: QM31) -> Self {
|
||||
Self::from_fn(|_| PackedQM31::broadcast(value))
|
||||
}
|
||||
|
||||
pub fn from_very_packed_m31s([a, b, c, d]: [VeryPackedM31; 4]) -> Self {
|
||||
Self::from_fn(|i| PackedQM31::from_packed_m31s([a.0[i], b.0[i], c.0[i], d.0[i]]))
|
||||
}
|
||||
}
|
||||
impl From<M31> for VeryPackedM31 {
|
||||
fn from(v: M31) -> Self {
|
||||
Self::broadcast(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<VeryPackedM31> for VeryPackedQM31 {
|
||||
fn from(value: VeryPackedM31) -> Self {
|
||||
VeryPackedQM31::from_very_packed_m31s([
|
||||
value,
|
||||
VeryPackedM31::zero(),
|
||||
VeryPackedM31::zero(),
|
||||
VeryPackedM31::zero(),
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<QM31> for VeryPackedQM31 {
|
||||
fn from(value: QM31) -> Self {
|
||||
VeryPackedQM31::broadcast(value)
|
||||
}
|
||||
}
|
||||
|
||||
trait Scalar {}
|
||||
impl Scalar for M31 {}
|
||||
impl Scalar for CM31 {}
|
||||
impl Scalar for QM31 {}
|
||||
impl Scalar for PackedM31 {}
|
||||
impl Scalar for PackedCM31 {}
|
||||
impl Scalar for PackedQM31 {}
|
||||
|
||||
impl<A: Add<B> + Copy, B: Copy, const N: usize> Add<Vectorized<B, N>> for Vectorized<A, N> {
|
||||
type Output = Vectorized<A::Output, N>;
|
||||
|
||||
fn add(self, other: Vectorized<B, N>) -> Self::Output {
|
||||
Vectorized::from_fn(|i| self.0[i] + other.0[i])
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: Add<B> + Copy, B: Scalar + Copy, const N: usize> Add<B> for Vectorized<A, N> {
|
||||
type Output = Vectorized<A::Output, N>;
|
||||
|
||||
fn add(self, other: B) -> Self::Output {
|
||||
Vectorized::from_fn(|i| self.0[i] + other)
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: Sub<B> + Copy, B: Copy, const N: usize> Sub<Vectorized<B, N>> for Vectorized<A, N> {
|
||||
type Output = Vectorized<A::Output, N>;
|
||||
|
||||
fn sub(self, other: Vectorized<B, N>) -> Self::Output {
|
||||
Vectorized::from_fn(|i| self.0[i] - other.0[i])
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: Sub<B> + Copy, B: Scalar + Copy, const N: usize> Sub<B> for Vectorized<A, N> {
|
||||
type Output = Vectorized<A::Output, N>;
|
||||
|
||||
fn sub(self, other: B) -> Self::Output {
|
||||
Vectorized::from_fn(|i| self.0[i] - other)
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: Mul<B> + Copy, B: Copy, const N: usize> Mul<Vectorized<B, N>> for Vectorized<A, N> {
|
||||
type Output = Vectorized<A::Output, N>;
|
||||
|
||||
fn mul(self, other: Vectorized<B, N>) -> Self::Output {
|
||||
Vectorized::from_fn(|i| self.0[i] * other.0[i])
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: Mul<B> + Copy, B: Scalar + Copy, const N: usize> Mul<B> for Vectorized<A, N> {
|
||||
type Output = Vectorized<A::Output, N>;
|
||||
|
||||
fn mul(self, other: B) -> Self::Output {
|
||||
Vectorized::from_fn(|i| self.0[i] * other)
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: AddAssign<B> + Copy, B: Copy, const N: usize> AddAssign<Vectorized<B, N>>
|
||||
for Vectorized<A, N>
|
||||
{
|
||||
fn add_assign(&mut self, other: Vectorized<B, N>) {
|
||||
for i in 0..N {
|
||||
self.0[i] += other.0[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: AddAssign<B> + Copy, B: Scalar + Copy, const N: usize> AddAssign<B> for Vectorized<A, N> {
|
||||
fn add_assign(&mut self, other: B) {
|
||||
for i in 0..N {
|
||||
self.0[i] += other;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: MulAssign<B> + Copy, B: Copy, const N: usize> MulAssign<Vectorized<B, N>>
|
||||
for Vectorized<A, N>
|
||||
{
|
||||
fn mul_assign(&mut self, other: Vectorized<B, N>) {
|
||||
for i in 0..N {
|
||||
self.0[i] *= other.0[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: Neg + Copy, const N: usize> Neg for Vectorized<A, N> {
|
||||
type Output = Vectorized<A::Output, N>;
|
||||
|
||||
#[inline(always)]
|
||||
fn neg(self) -> Self::Output {
|
||||
Vectorized::from_fn(|i| self.0[i].neg())
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: Zero + Copy, const N: usize> Zero for Vectorized<A, N> {
|
||||
fn zero() -> Self {
|
||||
Vectorized::from_fn(|_| A::zero())
|
||||
}
|
||||
|
||||
fn is_zero(&self) -> bool {
|
||||
self.0.iter().all(A::is_zero)
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: One + Copy, const N: usize> One for Vectorized<A, N> {
|
||||
fn one() -> Self {
|
||||
Vectorized::from_fn(|_| A::one())
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: FieldExpOps + Zero, const N: usize> FieldExpOps for Vectorized<A, N> {
|
||||
fn inverse(&self) -> Self {
|
||||
Vectorized::from_fn(|i| {
|
||||
assert!(!self.0[i].is_zero(), "0 has no inverse");
|
||||
pow2147483645(self.0[i])
|
||||
})
|
||||
}
|
||||
}
|
||||
186
Stwo_wrapper/crates/prover/src/core/channel/blake2s.rs
Normal file
@ -0,0 +1,186 @@
|
||||
use std::iter;
|
||||
|
||||
use super::{Channel, ChannelTime};
|
||||
use crate::core::fields::m31::{BaseField, N_BYTES_FELT, P};
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
|
||||
use crate::core::fields::IntoSlice;
|
||||
use crate::core::vcs::blake2_hash::{Blake2sHash, Blake2sHasher};
|
||||
use crate::core::vcs::blake2s_ref::compress;
|
||||
|
||||
pub const BLAKE_BYTES_PER_HASH: usize = 32;
|
||||
pub const FELTS_PER_HASH: usize = 8;
|
||||
|
||||
/// A channel that can be used to draw random elements from a [Blake2sHash] digest.
|
||||
#[derive(Default, Clone)]
|
||||
pub struct Blake2sChannel {
|
||||
digest: Blake2sHash,
|
||||
pub channel_time: ChannelTime,
|
||||
}
|
||||
|
||||
impl Blake2sChannel {
|
||||
pub fn digest(&self) -> Blake2sHash {
|
||||
self.digest
|
||||
}
|
||||
pub fn update_digest(&mut self, new_digest: Blake2sHash) {
|
||||
self.digest = new_digest;
|
||||
self.channel_time.inc_challenges();
|
||||
}
|
||||
/// Generates a uniform random vector of BaseField elements.
|
||||
fn draw_base_felts(&mut self) -> [BaseField; FELTS_PER_HASH] {
|
||||
// Repeats hashing with an increasing counter until getting a good result.
|
||||
// Retry probability for each round is ~ 2^(-28).
|
||||
loop {
|
||||
let u32s: [u32; FELTS_PER_HASH] = self
|
||||
.draw_random_bytes()
|
||||
.chunks_exact(N_BYTES_FELT) // 4 bytes per u32.
|
||||
.map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap()))
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
// Retry if not all the u32 are in the range [0, 2P).
|
||||
if u32s.iter().all(|x| *x < 2 * P) {
|
||||
return u32s
|
||||
.into_iter()
|
||||
.map(|x| BaseField::reduce(x as u64))
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
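// Illustrative sketch (not part of the diff): the retry probability quoted in
// `draw_base_felts`. A sampled u32 is rejected when it lands in [2P, 2^32),
// which contains only 2 values for P = 2^31 - 1, so a single draw fails with
// probability 2^-31 and a batch of 8 fails with probability ~2^-28.
fn main() {
    let p: u64 = (1 << 31) - 1;
    let rejected_values = (1u64 << 32) - 2 * p; // = 2
    let per_u32 = rejected_values as f64 / 2f64.powi(32); // ~2^-31
    let per_batch = 1.0 - (1.0 - per_u32).powi(8); // ~8 * 2^-31 = 2^-28
    println!("per-u32: {per_u32:e}, per-batch: {per_batch:e}");
    assert!((per_batch.log2() + 28.0).abs() < 0.1);
}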
|
||||
impl Channel for Blake2sChannel {
|
||||
const BYTES_PER_HASH: usize = BLAKE_BYTES_PER_HASH;
|
||||
|
||||
fn trailing_zeros(&self) -> u32 {
|
||||
u128::from_le_bytes(std::array::from_fn(|i| self.digest.0[i])).trailing_zeros()
|
||||
}
|
||||
|
||||
fn mix_felts(&mut self, felts: &[SecureField]) {
|
||||
let mut hasher = Blake2sHasher::new();
|
||||
hasher.update(self.digest.as_ref());
|
||||
hasher.update(IntoSlice::<u8>::into_slice(felts));
|
||||
|
||||
self.update_digest(hasher.finalize());
|
||||
}
|
||||
|
||||
fn mix_nonce(&mut self, nonce: u64) {
|
||||
let digest: [u32; 8] = unsafe { std::mem::transmute(self.digest) };
|
||||
let mut msg = [0; 16];
|
||||
msg[0] = nonce as u32;
|
||||
msg[1] = (nonce >> 32) as u32;
|
||||
let res = compress(std::array::from_fn(|i| digest[i]), msg, 0, 0, 0, 0);
|
||||
|
||||
self.update_digest(unsafe { std::mem::transmute(res) });
|
||||
}
|
||||
|
||||
fn draw_felt(&mut self) -> SecureField {
|
||||
let felts: [BaseField; FELTS_PER_HASH] = self.draw_base_felts();
|
||||
SecureField::from_m31_array(felts[..SECURE_EXTENSION_DEGREE].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn draw_felts(&mut self, n_felts: usize) -> Vec<SecureField> {
|
||||
let mut felts = iter::from_fn(|| Some(self.draw_base_felts())).flatten();
|
||||
let secure_felts = iter::from_fn(|| {
|
||||
Some(SecureField::from_m31_array([
|
||||
felts.next()?,
|
||||
felts.next()?,
|
||||
felts.next()?,
|
||||
felts.next()?,
|
||||
]))
|
||||
});
|
||||
secure_felts.take(n_felts).collect()
|
||||
}
|
||||
|
||||
fn draw_random_bytes(&mut self) -> Vec<u8> {
|
||||
let mut hash_input = self.digest.as_ref().to_vec();
|
||||
|
||||
// Pad the counter to 32 bytes.
|
||||
let mut padded_counter = [0; BLAKE_BYTES_PER_HASH];
|
||||
let counter_bytes = self.channel_time.n_sent.to_le_bytes();
|
||||
padded_counter[0..counter_bytes.len()].copy_from_slice(&counter_bytes);
|
||||
|
||||
hash_input.extend_from_slice(&padded_counter);
|
||||
|
||||
// TODO(spapini): Are we worried about this drawing hash colliding with mix_digest?
|
||||
|
||||
self.channel_time.inc_sent();
|
||||
Blake2sHasher::hash(&hash_input).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::BTreeSet;
|
||||
|
||||
use crate::core::channel::blake2s::Blake2sChannel;
|
||||
use crate::core::channel::Channel;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::m31;
|
||||
|
||||
#[test]
|
||||
fn test_channel_time() {
|
||||
let mut channel = Blake2sChannel::default();
|
||||
|
||||
assert_eq!(channel.channel_time.n_challenges, 0);
|
||||
assert_eq!(channel.channel_time.n_sent, 0);
|
||||
|
||||
channel.draw_random_bytes();
|
||||
assert_eq!(channel.channel_time.n_challenges, 0);
|
||||
assert_eq!(channel.channel_time.n_sent, 1);
|
||||
|
||||
channel.draw_felts(9);
|
||||
assert_eq!(channel.channel_time.n_challenges, 0);
|
||||
assert_eq!(channel.channel_time.n_sent, 6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_draw_random_bytes() {
|
||||
let mut channel = Blake2sChannel::default();
|
||||
|
||||
let first_random_bytes = channel.draw_random_bytes();
|
||||
|
||||
// Assert that next random bytes are different.
|
||||
assert_ne!(first_random_bytes, channel.draw_random_bytes());
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_draw_felt() {
|
||||
let mut channel = Blake2sChannel::default();
|
||||
|
||||
let first_random_felt = channel.draw_felt();
|
||||
|
||||
// Assert that next random felt is different.
|
||||
assert_ne!(first_random_felt, channel.draw_felt());
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_draw_felts() {
|
||||
let mut channel = Blake2sChannel::default();
|
||||
|
||||
let mut random_felts = channel.draw_felts(5);
|
||||
random_felts.extend(channel.draw_felts(4));
|
||||
|
||||
// Assert that all the random felts are unique.
|
||||
assert_eq!(
|
||||
random_felts.len(),
|
||||
random_felts.iter().collect::<BTreeSet<_>>().len()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_mix_felts() {
|
||||
let mut channel = Blake2sChannel::default();
|
||||
let initial_digest = channel.digest;
|
||||
let felts: Vec<SecureField> = (0..2)
|
||||
.map(|i| SecureField::from(m31!(i + 1923782)))
|
||||
.collect();
|
||||
|
||||
channel.mix_felts(felts.as_slice());
|
||||
|
||||
assert_ne!(initial_digest, channel.digest);
|
||||
}
|
||||
}
|
||||
57
Stwo_wrapper/crates/prover/src/core/channel/mod.rs
Normal file
@ -0,0 +1,57 @@
|
||||
use super::fields::qm31::SecureField;
|
||||
use super::vcs::ops::MerkleHasher;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
mod poseidon252;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub use poseidon252::Poseidon252Channel;
|
||||
|
||||
mod blake2s;
|
||||
pub use blake2s::Blake2sChannel;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
mod poseidon_bls;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub use poseidon_bls::PoseidonBLSChannel;
|
||||
|
||||
pub const EXTENSION_FELTS_PER_HASH: usize = 2;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct ChannelTime {
|
||||
pub n_challenges: usize,
|
||||
n_sent: usize,
|
||||
}
|
||||
|
||||
impl ChannelTime {
|
||||
fn inc_sent(&mut self) {
|
||||
self.n_sent += 1;
|
||||
}
|
||||
|
||||
fn inc_challenges(&mut self) {
|
||||
self.n_challenges += 1;
|
||||
self.n_sent = 0;
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Channel: Default + Clone {
|
||||
const BYTES_PER_HASH: usize;
|
||||
|
||||
fn trailing_zeros(&self) -> u32;
|
||||
|
||||
// Mix functions.
|
||||
fn mix_felts(&mut self, felts: &[SecureField]);
|
||||
fn mix_nonce(&mut self, nonce: u64);
|
||||
|
||||
// Draw functions.
|
||||
fn draw_felt(&mut self) -> SecureField;
|
||||
/// Generates a uniform random vector of SecureField elements.
|
||||
fn draw_felts(&mut self, n_felts: usize) -> Vec<SecureField>;
|
||||
/// Returns a vector of random bytes of length `BYTES_PER_HASH`.
|
||||
fn draw_random_bytes(&mut self) -> Vec<u8>;
|
||||
}
|
||||
|
||||
pub trait MerkleChannel: Default {
|
||||
type C: Channel;
|
||||
type H: MerkleHasher;
|
||||
fn mix_root(channel: &mut Self::C, root: <Self::H as MerkleHasher>::Hash);
|
||||
}
|
||||
190
Stwo_wrapper/crates/prover/src/core/channel/poseidon252.rs
Normal file
@ -0,0 +1,190 @@
|
||||
use std::iter;
|
||||
|
||||
use starknet_crypto::{poseidon_hash, poseidon_hash_many};
|
||||
use starknet_ff::FieldElement as FieldElement252;
|
||||
|
||||
use super::{Channel, ChannelTime};
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
|
||||
|
||||
pub const BYTES_PER_FELT252: usize = 31;
|
||||
pub const FELTS_PER_HASH: usize = 8;
|
||||
|
||||
/// A channel that can be used to draw random elements from a Poseidon252 hash.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct Poseidon252Channel {
|
||||
digest: FieldElement252,
|
||||
pub channel_time: ChannelTime,
|
||||
}
|
||||
|
||||
impl Poseidon252Channel {
|
||||
pub fn digest(&self) -> FieldElement252 {
|
||||
self.digest
|
||||
}
|
||||
pub fn update_digest(&mut self, new_digest: FieldElement252) {
|
||||
self.digest = new_digest;
|
||||
self.channel_time.inc_challenges();
|
||||
}
|
||||
fn draw_felt252(&mut self) -> FieldElement252 {
|
||||
let res = poseidon_hash(self.digest, self.channel_time.n_sent.into());
|
||||
self.channel_time.inc_sent();
|
||||
res
|
||||
}
|
||||
|
||||
// TODO(spapini): Understand if we really need uniformity here.
|
||||
/// Generates a close-to-uniform random vector of BaseField elements.
|
||||
fn draw_base_felts(&mut self) -> [BaseField; 8] {
|
||||
let shift = (1u64 << 31).into();
|
||||
|
||||
let mut cur = self.draw_felt252();
|
||||
let u32s: [u32; 8] = std::array::from_fn(|_| {
|
||||
let next = cur.floor_div(shift);
|
||||
let res = cur - next * shift;
|
||||
cur = next;
|
||||
res.try_into().unwrap()
|
||||
});
|
||||
|
||||
u32s.into_iter()
|
||||
.map(|x| BaseField::reduce(x as u64))
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
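// Illustrative sketch (not part of the diff): the limb-extraction pattern used
// by `draw_base_felts` above, shown on a u128 instead of a FieldElement252.
// Repeated floor-division by 2^31 peels off 31-bit limbs, each of which can
// then be reduced into the M31 field.
fn main() {
    let shift: u128 = 1 << 31;
    let mut cur: u128 = 0x0123_4567_89ab_cdef_0123_4567_89ab_cdef;
    let limbs: [u32; 4] = std::array::from_fn(|_| {
        let next = cur / shift;
        let limb = (cur - next * shift) as u32; // cur mod 2^31
        cur = next;
        limb
    });
    // Each limb is < 2^31, so it fits a single reduction mod P = 2^31 - 1.
    assert!(limbs.iter().all(|&l| l < 1 << 31));
}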
|
||||
impl Channel for Poseidon252Channel {
|
||||
const BYTES_PER_HASH: usize = BYTES_PER_FELT252;
|
||||
|
||||
fn trailing_zeros(&self) -> u32 {
|
||||
let bytes = self.digest.to_bytes_be();
|
||||
u128::from_le_bytes(std::array::from_fn(|i| bytes[i])).trailing_zeros()
|
||||
}
|
||||
|
||||
// TODO(spapini): Optimize.
|
||||
fn mix_felts(&mut self, felts: &[SecureField]) {
|
||||
let shift = (1u64 << 31).into();
|
||||
let mut res = Vec::with_capacity(felts.len() / 2 + 2);
|
||||
res.push(self.digest);
|
||||
for chunk in felts.chunks(2) {
|
||||
res.push(
|
||||
chunk
|
||||
.iter()
|
||||
.flat_map(|x| x.to_m31_array())
|
||||
.fold(FieldElement252::default(), |cur, y| {
|
||||
cur * shift + y.0.into()
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
// TODO(spapini): do we need length padding?
|
||||
self.update_digest(poseidon_hash_many(&res));
|
||||
}
|
||||
|
||||
fn mix_nonce(&mut self, nonce: u64) {
|
||||
self.update_digest(poseidon_hash(self.digest, nonce.into()));
|
||||
}
|
||||
|
||||
fn draw_felt(&mut self) -> SecureField {
|
||||
let felts: [BaseField; FELTS_PER_HASH] = self.draw_base_felts();
|
||||
SecureField::from_m31_array(felts[..SECURE_EXTENSION_DEGREE].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn draw_felts(&mut self, n_felts: usize) -> Vec<SecureField> {
|
||||
let mut felts = iter::from_fn(|| Some(self.draw_base_felts())).flatten();
|
||||
let secure_felts = iter::from_fn(|| {
|
||||
Some(SecureField::from_m31_array([
|
||||
felts.next()?,
|
||||
felts.next()?,
|
||||
felts.next()?,
|
||||
felts.next()?,
|
||||
]))
|
||||
});
|
||||
secure_felts.take(n_felts).collect()
|
||||
}
|
||||
|
||||
fn draw_random_bytes(&mut self) -> Vec<u8> {
|
||||
let shift = (1u64 << 8).into();
|
||||
let mut cur = self.draw_felt252();
|
||||
let bytes: [u8; 31] = std::array::from_fn(|_| {
|
||||
let next = cur.floor_div(shift);
|
||||
let res = cur - next * shift;
|
||||
cur = next;
|
||||
res.try_into().unwrap()
|
||||
});
|
||||
bytes.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::BTreeSet;
|
||||
|
||||
use crate::core::channel::poseidon252::Poseidon252Channel;
|
||||
use crate::core::channel::Channel;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::m31;
|
||||
|
||||
#[test]
|
||||
fn test_channel_time() {
|
||||
let mut channel = Poseidon252Channel::default();
|
||||
|
||||
assert_eq!(channel.channel_time.n_challenges, 0);
|
||||
assert_eq!(channel.channel_time.n_sent, 0);
|
||||
|
||||
channel.draw_random_bytes();
|
||||
assert_eq!(channel.channel_time.n_challenges, 0);
|
||||
assert_eq!(channel.channel_time.n_sent, 1);
|
||||
|
||||
channel.draw_felts(9);
|
||||
assert_eq!(channel.channel_time.n_challenges, 0);
|
||||
assert_eq!(channel.channel_time.n_sent, 6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_draw_random_bytes() {
|
||||
let mut channel = Poseidon252Channel::default();
|
||||
|
||||
let first_random_bytes = channel.draw_random_bytes();
|
||||
|
||||
// Assert that next random bytes are different.
|
||||
assert_ne!(first_random_bytes, channel.draw_random_bytes());
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_draw_felt() {
|
||||
let mut channel = Poseidon252Channel::default();
|
||||
|
||||
let first_random_felt = channel.draw_felt();
|
||||
|
||||
// Assert that next random felt is different.
|
||||
assert_ne!(first_random_felt, channel.draw_felt());
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_draw_felts() {
|
||||
let mut channel = Poseidon252Channel::default();
|
||||
|
||||
let mut random_felts = channel.draw_felts(5);
|
||||
random_felts.extend(channel.draw_felts(4));
|
||||
|
||||
// Assert that all the random felts are unique.
|
||||
assert_eq!(
|
||||
random_felts.len(),
|
||||
random_felts.iter().collect::<BTreeSet<_>>().len()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_mix_felts() {
|
||||
let mut channel = Poseidon252Channel::default();
|
||||
let initial_digest = channel.digest;
|
||||
let felts: Vec<SecureField> = (0..2)
|
||||
.map(|i| SecureField::from(m31!(i + 1923782)))
|
||||
.collect();
|
||||
|
||||
channel.mix_felts(felts.as_slice());
|
||||
|
||||
assert_ne!(initial_digest, channel.digest);
|
||||
}
|
||||
}
|
||||
590
Stwo_wrapper/crates/prover/src/core/channel/poseidon_bls.rs
Normal file
@ -0,0 +1,590 @@
|
||||
use std::iter;
|
||||
|
||||
use ark_bls12_381::Fr as BlsFr;
|
||||
use ark_ff::{BigInteger, Field, PrimeField};
|
||||
use crypto_bigint::{Encoding, NonZero, U256};
|
||||
|
||||
use super::{Channel, ChannelTime};
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
|
||||
|
||||
pub const BYTES_PER_FELT252: usize = 32;
|
||||
pub const FELTS_PER_HASH: usize = 8;
|
||||
|
||||
// TODO: Optimize the constants to be real constants (no conversion) and merge duplicated code with the Poseidon in the VCS module.
|
||||
fn poseidon_comp_consts(idx: usize) -> BlsFr {
|
||||
match idx {
|
||||
0 => BlsFr::from_be_bytes_mod_order(&[
|
||||
111, 0, 122, 85, 17, 86, 179, 164, 73, 228, 73, 54, 183, 192, 147, 100, 74, 14, 211,
|
||||
63, 51, 234, 204, 198, 40, 233, 66, 232, 54, 193, 168, 117,
|
||||
]),
|
||||
1 => BlsFr::from_be_bytes_mod_order(&[
|
||||
54, 13, 116, 112, 97, 30, 71, 61, 53, 63, 98, 143, 118, 209, 16, 243, 78, 113, 22, 47,
|
||||
49, 0, 59, 112, 87, 83, 140, 37, 150, 66, 99, 3,
|
||||
]),
|
||||
2 => BlsFr::from_be_bytes_mod_order(&[
|
||||
75, 95, 236, 58, 160, 115, 223, 68, 1, 144, 145, 240, 7, 164, 76, 169, 150, 72, 73,
|
||||
101, 247, 3, 109, 206, 62, 157, 9, 119, 237, 205, 192, 246,
|
||||
]),
|
||||
3 => BlsFr::from_be_bytes_mod_order(&[
|
||||
103, 207, 24, 104, 175, 99, 150, 192, 184, 76, 206, 113, 94, 83, 159, 132, 158, 6, 205,
|
||||
28, 56, 58, 197, 176, 97, 0, 199, 107, 204, 151, 58, 17,
|
||||
]),
|
||||
4 => BlsFr::from_be_bytes_mod_order(&[
|
||||
85, 93, 180, 209, 220, 237, 129, 159, 93, 61, 231, 15, 222, 131, 241, 199, 211, 232,
|
||||
201, 137, 104, 229, 22, 162, 58, 119, 26, 92, 156, 130, 87, 170,
|
||||
]),
|
||||
5 => BlsFr::from_be_bytes_mod_order(&[
|
||||
43, 171, 148, 215, 174, 34, 45, 19, 93, 195, 198, 197, 254, 191, 170, 49, 73, 8, 172,
|
||||
47, 18, 235, 224, 111, 189, 183, 66, 19, 191, 99, 24, 139,
|
||||
]),
|
||||
6 => BlsFr::from_be_bytes_mod_order(&[
|
||||
102, 244, 75, 229, 41, 102, 130, 196, 250, 120, 130, 121, 157, 109, 208, 73, 182, 215,
|
||||
210, 201, 80, 204, 249, 140, 242, 229, 13, 109, 30, 187, 119, 194,
|
||||
]),
|
||||
7 => BlsFr::from_be_bytes_mod_order(&[
|
||||
21, 12, 147, 254, 246, 82, 251, 28, 43, 240, 62, 26, 41, 170, 135, 31, 239, 119, 231,
|
||||
215, 54, 118, 108, 93, 9, 57, 217, 39, 83, 204, 93, 200,
|
||||
]),
|
||||
8 => BlsFr::from_be_bytes_mod_order(&[
|
||||
50, 112, 102, 30, 104, 146, 139, 58, 149, 93, 85, 219, 86, 220, 87, 193, 3, 204, 10,
|
||||
96, 20, 30, 137, 78, 20, 37, 157, 206, 83, 119, 130, 178,
|
||||
]),
|
||||
9 => BlsFr::from_be_bytes_mod_order(&[
|
||||
7, 63, 17, 111, 4, 18, 46, 37, 160, 183, 175, 228, 226, 5, 114, 153, 180, 7, 195, 112,
|
||||
242, 181, 161, 204, 206, 159, 185, 255, 195, 69, 175, 179,
|
||||
]),
|
||||
10 => BlsFr::from_be_bytes_mod_order(&[
|
||||
64, 159, 218, 34, 85, 140, 254, 77, 61, 216, 220, 226, 79, 105, 231, 111, 140, 42, 174,
|
||||
177, 221, 15, 9, 214, 94, 101, 76, 113, 243, 42, 162, 63,
|
||||
]),
|
||||
11 => BlsFr::from_be_bytes_mod_order(&[
|
||||
42, 50, 236, 92, 78, 229, 177, 131, 122, 255, 208, 156, 31, 83, 245, 253, 85, 201, 205,
|
||||
32, 97, 174, 147, 202, 142, 186, 215, 111, 199, 21, 84, 216,
|
||||
]),
|
||||
12 => BlsFr::from_be_bytes_mod_order(&[
|
||||
88, 72, 235, 235, 89, 35, 233, 37, 85, 183, 18, 79, 255, 186, 93, 107, 213, 113, 198,
|
||||
249, 132, 25, 94, 185, 207, 211, 163, 232, 235, 85, 177, 212,
|
||||
]),
|
||||
13 => BlsFr::from_be_bytes_mod_order(&[
|
||||
39, 3, 38, 238, 3, 157, 241, 158, 101, 30, 44, 252, 116, 6, 40, 202, 99, 77, 36, 252,
|
||||
110, 37, 89, 242, 45, 140, 203, 226, 146, 239, 238, 173,
|
||||
]),
|
||||
14 => BlsFr::from_be_bytes_mod_order(&[
|
||||
39, 198, 100, 42, 198, 51, 188, 102, 220, 16, 15, 231, 252, 250, 84, 145, 138, 248,
|
||||
149, 188, 224, 18, 241, 130, 160, 104, 252, 55, 193, 130, 226, 116,
|
||||
]),
|
||||
15 => BlsFr::from_be_bytes_mod_order(&[
|
||||
27, 223, 216, 176, 20, 1, 199, 10, 210, 127, 87, 57, 105, 137, 18, 157, 113, 14, 31,
|
||||
182, 171, 151, 106, 69, 156, 161, 134, 130, 226, 109, 127, 249,
|
||||
]),
|
||||
16 => BlsFr::from_be_bytes_mod_order(&[
|
||||
73, 27, 155, 166, 152, 59, 207, 159, 5, 254, 71, 148, 173, 180, 74, 48, 135, 155, 248,
|
||||
40, 150, 98, 225, 245, 125, 144, 246, 114, 65, 78, 138, 74,
|
||||
]),
|
||||
17 => BlsFr::from_be_bytes_mod_order(&[
|
||||
22, 42, 20, 198, 47, 154, 137, 184, 20, 185, 214, 169, 200, 77, 214, 120, 244, 246,
|
||||
251, 63, 144, 84, 211, 115, 200, 50, 216, 36, 38, 26, 53, 234,
|
||||
]),
|
||||
18 => BlsFr::from_be_bytes_mod_order(&[
|
||||
45, 25, 62, 15, 118, 222, 88, 107, 42, 246, 247, 158, 49, 39, 254, 234, 172, 10, 31,
|
||||
199, 30, 44, 240, 192, 247, 152, 36, 102, 123, 91, 107, 236,
|
||||
]),
|
||||
19 => BlsFr::from_be_bytes_mod_order(&[
|
||||
70, 239, 216, 169, 162, 98, 214, 216, 253, 201, 202, 92, 4, 176, 152, 47, 36, 221, 204,
|
||||
110, 152, 99, 136, 90, 106, 115, 42, 57, 6, 160, 123, 149,
|
||||
]),
|
||||
20 => BlsFr::from_be_bytes_mod_order(&[
|
||||
80, 151, 23, 224, 194, 0, 227, 201, 45, 141, 202, 41, 115, 179, 219, 69, 240, 120, 130,
|
||||
148, 53, 26, 208, 122, 231, 92, 187, 120, 6, 147, 167, 152,
|
||||
]),
|
||||
21 => BlsFr::from_be_bytes_mod_order(&[
|
||||
114, 153, 178, 132, 100, 168, 201, 79, 185, 212, 223, 97, 56, 15, 57, 192, 220, 169,
|
||||
194, 192, 20, 17, 135, 137, 226, 39, 37, 40, 32, 240, 27, 252,
|
||||
]),
|
||||
22 => BlsFr::from_be_bytes_mod_order(&[
|
||||
4, 76, 163, 204, 74, 133, 215, 59, 129, 105, 110, 241, 16, 78, 103, 79, 79, 239, 248,
|
||||
41, 132, 153, 15, 248, 93, 11, 245, 141, 200, 164, 170, 148,
|
||||
]),
|
||||
23 => BlsFr::from_be_bytes_mod_order(&[
|
||||
28, 186, 242, 179, 113, 218, 198, 168, 29, 4, 83, 65, 109, 62, 35, 92, 184, 217, 226,
|
||||
212, 243, 20, 244, 111, 97, 152, 120, 95, 12, 214, 185, 175,
|
||||
]),
|
||||
24 => BlsFr::from_be_bytes_mod_order(&[
|
||||
29, 91, 39, 119, 105, 44, 32, 91, 14, 108, 73, 208, 97, 182, 181, 244, 41, 60, 74, 176,
|
||||
56, 253, 187, 220, 52, 62, 7, 97, 15, 63, 237, 229,
|
||||
]),
|
||||
25 => BlsFr::from_be_bytes_mod_order(&[
|
||||
86, 174, 124, 122, 82, 147, 189, 194, 62, 133, 225, 105, 140, 129, 199, 127, 138, 216,
|
||||
140, 75, 51, 165, 120, 4, 55, 173, 4, 124, 110, 219, 89, 186,
|
||||
]),
|
||||
26 => BlsFr::from_be_bytes_mod_order(&[
|
||||
46, 155, 219, 186, 61, 211, 75, 255, 170, 48, 83, 91, 221, 116, 154, 126, 6, 169, 173,
|
||||
176, 193, 230, 249, 98, 246, 14, 151, 27, 141, 115, 176, 79,
|
||||
]),
|
||||
27 => BlsFr::from_be_bytes_mod_order(&[
|
||||
45, 225, 24, 134, 177, 128, 17, 202, 139, 213, 186, 227, 105, 105, 41, 159, 222, 64,
|
||||
251, 226, 109, 4, 123, 5, 3, 90, 19, 102, 31, 34, 65, 139,
|
||||
]),
|
||||
28 => BlsFr::from_be_bytes_mod_order(&[
|
||||
46, 7, 222, 23, 128, 184, 167, 13, 13, 91, 74, 63, 24, 65, 220, 216, 42, 185, 57, 92,
|
||||
68, 155, 233, 71, 188, 153, 136, 132, 186, 150, 167, 33,
|
||||
]),
|
||||
29 => BlsFr::from_be_bytes_mod_order(&[
|
||||
15, 105, 241, 133, 77, 32, 202, 12, 187, 219, 99, 219, 213, 45, 173, 22, 37, 4, 64,
|
||||
169, 157, 107, 138, 243, 130, 94, 76, 43, 183, 73, 37, 202,
|
||||
]),
|
||||
30 => BlsFr::from_be_bytes_mod_order(&[
|
||||
93, 201, 135, 49, 142, 110, 89, 193, 175, 184, 123, 101, 93, 213, 140, 193, 210, 46,
|
||||
81, 58, 5, 131, 140, 212, 88, 93, 4, 177, 53, 185, 87, 202,
|
||||
]),
|
||||
31 => BlsFr::from_be_bytes_mod_order(&[
|
||||
72, 183, 37, 117, 133, 113, 201, 223, 108, 1, 220, 99, 154, 133, 240, 114, 151, 105,
|
||||
107, 27, 182, 120, 99, 58, 41, 220, 145, 222, 149, 239, 83, 246,
|
||||
]),
|
||||
32 => BlsFr::from_be_bytes_mod_order(&[
|
||||
94, 86, 94, 8, 192, 130, 16, 153, 37, 107, 86, 73, 14, 174, 225, 213, 115, 175, 209,
|
||||
11, 182, 209, 125, 19, 202, 78, 92, 97, 27, 42, 55, 24,
|
||||
]),
|
||||
33 => BlsFr::from_be_bytes_mod_order(&[
|
||||
46, 177, 178, 84, 23, 254, 23, 103, 13, 19, 93, 198, 57, 251, 9, 164, 108, 229, 17, 53,
|
||||
7, 249, 109, 233, 129, 108, 5, 148, 34, 220, 112, 94,
|
||||
]),
|
||||
34 => BlsFr::from_be_bytes_mod_order(&[
|
||||
17, 92, 208, 160, 100, 60, 251, 152, 140, 36, 203, 68, 195, 250, 180, 138, 255, 54,
|
||||
198, 97, 210, 108, 196, 45, 184, 177, 189, 244, 149, 59, 216, 44,
|
||||
]),
|
||||
35 => BlsFr::from_be_bytes_mod_order(&[
|
||||
38, 202, 41, 63, 123, 44, 70, 45, 6, 109, 115, 120, 185, 153, 134, 139, 187, 87, 221,
|
||||
241, 78, 15, 149, 138, 222, 128, 22, 18, 49, 29, 4, 205,
|
||||
]),
|
||||
36 => BlsFr::from_be_bytes_mod_order(&[
|
||||
65, 71, 64, 13, 142, 26, 172, 207, 49, 26, 107, 91, 118, 32, 17, 171, 62, 69, 50, 110,
|
||||
77, 75, 157, 226, 105, 146, 129, 107, 153, 197, 40, 172,
|
||||
]),
|
||||
37 => BlsFr::from_be_bytes_mod_order(&[
|
||||
107, 13, 183, 220, 204, 75, 161, 178, 104, 246, 189, 204, 77, 55, 40, 72, 212, 167, 41,
|
||||
118, 194, 104, 234, 48, 81, 154, 47, 115, 230, 219, 77, 85,
|
||||
]),
|
||||
38 => BlsFr::from_be_bytes_mod_order(&[
|
||||
23, 191, 27, 147, 196, 199, 224, 26, 42, 131, 10, 161, 98, 65, 44, 217, 15, 22, 11,
|
||||
249, 247, 30, 150, 127, 245, 32, 157, 20, 178, 72, 32, 202,
|
||||
]),
|
||||
39 => BlsFr::from_be_bytes_mod_order(&[
|
||||
75, 67, 28, 217, 239, 237, 188, 148, 207, 30, 202, 111, 158, 156, 24, 57, 208, 230,
|
||||
106, 139, 255, 168, 200, 70, 76, 172, 129, 163, 157, 60, 248, 241,
|
||||
]),
|
||||
40 => BlsFr::from_be_bytes_mod_order(&[
|
||||
53, 180, 26, 122, 196, 243, 197, 113, 162, 79, 132, 86, 54, 156, 133, 223, 224, 60, 3,
|
||||
84, 189, 140, 253, 56, 5, 200, 111, 46, 125, 194, 147, 197,
|
||||
]),
|
||||
41 => BlsFr::from_be_bytes_mod_order(&[
|
||||
59, 20, 128, 8, 5, 35, 196, 57, 67, 89, 39, 153, 72, 73, 190, 169, 100, 225, 77, 59,
|
||||
235, 45, 221, 222, 114, 172, 21, 106, 244, 53, 208, 158,
|
||||
]),
|
||||
42 => BlsFr::from_be_bytes_mod_order(&[
|
||||
44, 198, 129, 0, 49, 220, 27, 13, 73, 80, 133, 109, 201, 7, 213, 117, 8, 226, 134, 68,
|
||||
42, 45, 62, 178, 39, 22, 24, 216, 116, 177, 76, 109,
|
||||
]),
|
||||
43 => BlsFr::from_be_bytes_mod_order(&[
|
||||
111, 65, 65, 200, 64, 28, 90, 57, 91, 166, 121, 14, 253, 113, 199, 12, 4, 175, 234, 6,
|
||||
195, 201, 40, 38, 188, 171, 221, 92, 181, 71, 125, 81,
|
||||
]),
|
||||
44 => BlsFr::from_be_bytes_mod_order(&[
|
||||
37, 189, 187, 237, 161, 189, 232, 193, 5, 150, 24, 226, 175, 210, 239, 153, 158, 81,
|
||||
122, 169, 59, 120, 52, 29, 145, 243, 24, 192, 159, 12, 181, 102,
|
||||
]),
|
||||
45 => BlsFr::from_be_bytes_mod_order(&[
|
||||
57, 42, 74, 135, 88, 224, 110, 232, 185, 95, 51, 194, 93, 222, 138, 192, 42, 94, 208,
|
||||
162, 123, 97, 146, 108, 198, 49, 52, 135, 7, 63, 127, 123,
|
||||
]),
|
||||
46 => BlsFr::from_be_bytes_mod_order(&[
|
||||
39, 42, 85, 135, 138, 8, 68, 43, 154, 166, 17, 31, 77, 224, 9, 72, 94, 106, 111, 209,
|
||||
93, 184, 147, 101, 231, 187, 206, 240, 46, 181, 134, 108,
|
||||
]),
|
||||
47 => BlsFr::from_be_bytes_mod_order(&[
|
||||
99, 30, 193, 214, 210, 141, 217, 232, 36, 238, 137, 163, 7, 48, 174, 247, 171, 70, 58,
|
||||
207, 201, 209, 132, 179, 85, 170, 5, 253, 105, 56, 234, 181,
|
||||
]),
|
||||
48 => BlsFr::from_be_bytes_mod_order(&[
|
||||
78, 182, 253, 161, 15, 208, 251, 222, 2, 199, 68, 155, 251, 221, 195, 91, 205, 130, 37,
|
||||
231, 229, 195, 131, 58, 8, 24, 161, 0, 64, 157, 198, 242,
|
||||
]),
|
||||
49 => BlsFr::from_be_bytes_mod_order(&[
|
||||
45, 91, 48, 139, 12, 240, 44, 223, 239, 161, 60, 78, 96, 226, 98, 57, 166, 235, 186, 1,
|
||||
22, 148, 221, 18, 155, 146, 91, 60, 91, 33, 224, 226,
|
||||
]),
|
||||
50 => BlsFr::from_be_bytes_mod_order(&[
|
||||
22, 84, 159, 198, 175, 47, 59, 114, 221, 93, 41, 61, 114, 226, 229, 242, 68, 223, 244,
|
||||
47, 24, 180, 108, 86, 239, 56, 197, 124, 49, 22, 115, 172,
|
||||
]),
|
||||
51 => BlsFr::from_be_bytes_mod_order(&[
|
||||
66, 51, 38, 119, 255, 53, 156, 94, 141, 184, 54, 217, 245, 251, 84, 130, 46, 57, 189,
|
||||
94, 34, 52, 11, 185, 186, 151, 91, 161, 169, 43, 227, 130,
|
||||
]),
|
||||
52 => BlsFr::from_be_bytes_mod_order(&[
|
||||
73, 215, 210, 192, 180, 73, 229, 23, 155, 197, 204, 195, 180, 76, 96, 117, 217, 132,
|
||||
155, 86, 16, 70, 95, 9, 234, 114, 93, 220, 151, 114, 58, 148,
|
||||
]),
|
||||
53 => BlsFr::from_be_bytes_mod_order(&[
|
||||
100, 194, 15, 185, 13, 122, 0, 56, 49, 117, 124, 196, 198, 34, 111, 110, 73, 133, 252,
|
||||
158, 203, 65, 107, 159, 104, 76, 160, 53, 29, 150, 121, 4,
|
||||
]),
|
||||
54 => BlsFr::from_be_bytes_mod_order(&[
|
||||
89, 207, 244, 13, 232, 59, 82, 180, 27, 196, 67, 215, 151, 149, 16, 215, 113, 201, 64,
|
||||
185, 117, 140, 168, 32, 254, 115, 181, 200, 213, 88, 9, 52,
|
||||
]),
|
||||
55 => BlsFr::from_be_bytes_mod_order(&[
|
||||
83, 219, 39, 49, 115, 12, 57, 176, 78, 221, 135, 95, 227, 183, 200, 130, 128, 130, 133,
|
||||
205, 188, 98, 29, 122, 244, 248, 13, 213, 62, 187, 113, 176,
|
||||
]),
|
||||
56 => BlsFr::from_be_bytes_mod_order(&[
|
||||
27, 16, 187, 122, 130, 175, 206, 57, 250, 105, 195, 162, 173, 82, 247, 109, 118, 57,
|
||||
130, 101, 52, 66, 3, 17, 155, 113, 38, 217, 180, 104, 96, 223,
|
||||
]),
|
||||
57 => BlsFr::from_be_bytes_mod_order(&[
|
||||
86, 27, 96, 18, 214, 102, 191, 225, 121, 196, 221, 127, 132, 205, 209, 83, 21, 150,
|
||||
211, 170, 199, 197, 112, 12, 235, 49, 159, 145, 4, 106, 99, 201,
|
||||
]),
|
||||
58 => BlsFr::from_be_bytes_mod_order(&[
|
||||
15, 30, 117, 5, 235, 217, 29, 47, 199, 156, 45, 247, 220, 152, 163, 190, 209, 179, 105,
|
||||
104, 186, 4, 5, 192, 144, 210, 127, 106, 0, 183, 223, 200,
|
||||
]),
|
||||
59 => BlsFr::from_be_bytes_mod_order(&[
|
||||
47, 49, 63, 175, 13, 63, 97, 135, 83, 122, 116, 151, 163, 180, 63, 70, 121, 127, 214,
|
||||
227, 241, 142, 177, 202, 255, 69, 119, 86, 184, 25, 187, 32,
|
||||
]),
|
||||
60 => BlsFr::from_be_bytes_mod_order(&[
|
||||
58, 92, 187, 109, 228, 80, 180, 129, 250, 60, 166, 28, 14, 209, 91, 197, 92, 173, 17,
|
||||
235, 240, 247, 206, 184, 240, 188, 62, 115, 46, 203, 38, 246,
|
||||
]),
|
||||
61 => BlsFr::from_be_bytes_mod_order(&[
|
||||
104, 29, 147, 65, 27, 248, 206, 99, 246, 113, 106, 239, 189, 14, 36, 80, 100, 84, 192,
|
||||
52, 142, 227, 143, 171, 235, 38, 71, 2, 113, 76, 207, 148,
|
||||
]),
|
||||
62 => BlsFr::from_be_bytes_mod_order(&[
|
||||
81, 120, 233, 64, 245, 0, 4, 49, 38, 70, 180, 54, 114, 127, 14, 128, 167, 184, 242,
|
||||
233, 238, 31, 220, 103, 124, 72, 49, 167, 103, 39, 119, 251,
|
||||
]),
|
||||
63 => BlsFr::from_be_bytes_mod_order(&[
|
||||
61, 171, 84, 188, 155, 239, 104, 141, 217, 32, 134, 226, 83, 180, 57, 214, 81, 186,
|
||||
166, 226, 15, 137, 43, 98, 134, 85, 39, 203, 202, 145, 89, 130,
|
||||
]),
|
||||
64 => BlsFr::from_be_bytes_mod_order(&[
|
||||
75, 60, 231, 83, 17, 33, 143, 154, 233, 5, 248, 78, 170, 91, 43, 56, 24, 68, 139, 191,
|
||||
57, 114, 225, 170, 214, 157, 227, 33, 0, 144, 21, 208,
|
||||
]),
|
||||
65 => BlsFr::from_be_bytes_mod_order(&[
|
||||
6, 219, 251, 66, 185, 121, 136, 77, 226, 128, 211, 22, 112, 18, 63, 116, 76, 36, 179,
|
||||
59, 65, 15, 239, 212, 54, 128, 69, 172, 242, 183, 26, 227,
|
||||
]),
|
||||
66 => BlsFr::from_be_bytes_mod_order(&[
|
||||
6, 141, 107, 70, 8, 170, 232, 16, 198, 240, 57, 234, 25, 115, 166, 62, 184, 210, 222,
|
||||
114, 227, 210, 201, 236, 167, 252, 50, 210, 47, 24, 185, 211,
|
||||
]),
|
||||
67 => BlsFr::from_be_bytes_mod_order(&[
|
||||
76, 92, 37, 69, 137, 169, 42, 54, 8, 74, 87, 211, 177, 217, 100, 39, 138, 204, 126, 79,
|
||||
232, 246, 159, 41, 85, 149, 79, 39, 167, 156, 235, 239,
|
||||
]),
|
||||
68 => BlsFr::from_be_bytes_mod_order(&[
|
||||
108, 186, 197, 225, 112, 9, 132, 235, 195, 45, 161, 91, 75, 185, 104, 63, 170, 186,
|
||||
181, 95, 103, 204, 196, 247, 29, 149, 96, 179, 71, 90, 119, 235,
|
||||
]),
|
||||
69 => BlsFr::from_be_bytes_mod_order(&[
|
||||
70, 3, 196, 3, 187, 250, 154, 23, 115, 138, 92, 98, 120, 234, 171, 28, 55, 236, 48,
|
||||
176, 115, 122, 162, 64, 159, 196, 137, 128, 105, 235, 152, 60,
|
||||
]),
|
||||
70 => BlsFr::from_be_bytes_mod_order(&[
|
||||
104, 148, 231, 226, 43, 44, 29, 92, 112, 167, 18, 166, 52, 90, 230, 177, 146, 169, 200,
|
||||
51, 169, 35, 76, 49, 197, 106, 172, 209, 107, 194, 241, 0,
|
||||
]),
|
||||
71 => BlsFr::from_be_bytes_mod_order(&[
|
||||
91, 226, 203, 188, 68, 5, 58, 208, 138, 250, 77, 30, 171, 199, 243, 210, 49, 238, 167,
|
||||
153, 185, 63, 34, 110, 144, 91, 125, 77, 101, 197, 142, 187,
|
||||
]),
|
||||
72 => BlsFr::from_be_bytes_mod_order(&[
|
||||
88, 229, 95, 40, 123, 69, 58, 152, 8, 98, 74, 140, 42, 53, 61, 82, 141, 160, 247, 231,
|
||||
19, 165, 198, 208, 215, 113, 30, 71, 6, 63, 166, 17,
|
||||
]),
|
||||
73 => BlsFr::from_be_bytes_mod_order(&[
|
||||
54, 110, 191, 175, 163, 173, 56, 28, 14, 226, 88, 201, 184, 253, 252, 205, 184, 104,
|
||||
167, 215, 225, 241, 246, 154, 43, 93, 252, 197, 87, 37, 85, 223,
|
||||
]),
|
||||
74 => BlsFr::from_be_bytes_mod_order(&[
|
||||
69, 118, 106, 183, 40, 150, 140, 100, 47, 144, 217, 124, 207, 85, 4, 221, 193, 5, 24,
|
||||
168, 25, 235, 188, 196, 208, 156, 63, 93, 120, 77, 103, 206,
|
||||
]),
|
||||
75 => BlsFr::from_be_bytes_mod_order(&[
|
||||
57, 103, 143, 101, 81, 47, 30, 228, 4, 219, 48, 36, 244, 29, 63, 86, 126, 246, 109,
|
||||
137, 208, 68, 208, 34, 230, 188, 34, 158, 149, 188, 118, 177,
|
||||
]),
|
||||
76 => BlsFr::from_be_bytes_mod_order(&[
|
||||
70, 58, 237, 29, 47, 31, 149, 94, 48, 120, 190, 91, 247, 191, 196, 111, 192, 235, 140,
|
||||
81, 85, 25, 6, 168, 134, 143, 24, 255, 174, 48, 207, 79,
|
||||
]),
|
||||
77 => BlsFr::from_be_bytes_mod_order(&[
|
||||
33, 102, 143, 1, 106, 128, 99, 192, 213, 139, 119, 80, 163, 188, 47, 225, 207, 130,
|
||||
194, 95, 153, 220, 1, 164, 229, 52, 200, 143, 229, 61, 133, 254,
|
||||
]),
|
||||
78 => BlsFr::from_be_bytes_mod_order(&[
|
||||
57, 208, 9, 148, 168, 165, 4, 106, 27, 199, 73, 54, 62, 152, 167, 104, 227, 77, 234,
|
||||
86, 67, 159, 225, 149, 75, 239, 66, 155, 197, 51, 22, 8,
|
||||
]),
|
||||
79 => BlsFr::from_be_bytes_mod_order(&[
|
||||
77, 127, 93, 205, 120, 236, 233, 169, 51, 152, 77, 227, 44, 11, 72, 250, 194, 187, 169,
|
||||
31, 38, 25, 150, 184, 233, 209, 2, 23, 115, 189, 7, 204,
|
||||
]),
|
||||
_ => BlsFr::ZERO,
|
||||
}
|
||||
}
|
||||
|
||||
/// A channel that can be used to draw random elements from a PoseidonBLS hash.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct PoseidonBLSChannel {
|
||||
digest: BlsFr,
|
||||
pub channel_time: ChannelTime,
|
||||
}
|
||||
|
||||
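/// Two-to-one Poseidon hash over the BLS scalar field (`BlsFr`): permutes the state `[x, y, 0]`
/// and returns `state[0] + x` (a feed-forward of the first input).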
pub fn poseidon_hash_bls(x: BlsFr, y: BlsFr) -> BlsFr {
|
||||
let mut state = [x, y, BlsFr::ZERO];
|
||||
poseidon_permute_comp_bls(&mut state);
|
||||
state[0] + x
|
||||
}
|
||||
|
||||
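/// Width-3 Poseidon permutation: an initial mix, 4 full rounds, 56 partial rounds and 4 more
/// full rounds, consuming the 80 round constants from `poseidon_comp_consts` (indices 0..=79).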
pub fn poseidon_permute_comp_bls(state: &mut [BlsFr; 3]) {
|
||||
let mut idx = 0;
|
||||
mix(state);
|
||||
|
||||
// Full rounds
|
||||
for _ in 0..4 {
|
||||
round_comp(state, idx, true);
|
||||
idx += 3;
|
||||
}
|
||||
|
||||
// Partial rounds
|
||||
for _ in 0..56 {
|
||||
round_comp(state, idx, false);
|
||||
idx += 1;
|
||||
}
|
||||
|
||||
// Full rounds
|
||||
for _ in 0..4 {
|
||||
round_comp(state, idx, true);
|
||||
idx += 3;
|
||||
}
|
||||
}
|
||||
|
||||
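// A single round: full rounds add a constant to each state word and apply the x^5 S-box to all
// of them; partial rounds add one constant to state[0] and apply the S-box to state[2] only.
// Both variants end with the linear `mix`.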
#[inline]
|
||||
fn round_comp(state: &mut [BlsFr; 3], idx: usize, full: bool) {
|
||||
if full {
|
||||
state[0] += poseidon_comp_consts(idx);
|
||||
state[1] += poseidon_comp_consts(idx + 1);
|
||||
state[2] += poseidon_comp_consts(idx + 2);
|
||||
// Optimize multiplication
|
||||
state[0] = state[0] * state[0] * state[0] * state[0] * state[0];
|
||||
state[1] = state[1] * state[1] * state[1] * state[1] * state[1];
|
||||
state[2] = state[2] * state[2] * state[2] * state[2] * state[2];
|
||||
} else {
|
||||
state[0] += poseidon_comp_consts(idx);
|
||||
state[2] = state[2] * state[2] * state[2] * state[2] * state[2];
|
||||
}
|
||||
mix(state);
|
||||
}
|
||||
|
||||
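// Linear layer. Note that state[1] and state[2] are updated using the *new* state[0], so the
// updates are sequential rather than a matrix-vector product over the old state.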
#[inline(always)]
|
||||
fn mix(state: &mut [BlsFr; 3]) {
|
||||
state[0] = state[0] + state[1] + state[2];
|
||||
state[1] = state[0] + state[1];
|
||||
state[2] = state[0] + state[2];
|
||||
}
|
||||
|
||||
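/// Sponge-style hash of an arbitrary number of field elements: absorbs two elements per
/// permutation, pads by adding ONE at position `r.len()` (the leftover chunk length), then
/// permutes once more and squeezes `state[0]`.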
pub fn poseidon_hash_many_bls(msgs: &[BlsFr]) -> BlsFr {
|
||||
let mut state = [BlsFr::ZERO, BlsFr::ZERO, BlsFr::ZERO];
|
||||
let mut iter = msgs.chunks_exact(2);
|
||||
|
||||
for msg in iter.by_ref() {
|
||||
state[0] += msg[0];
|
||||
state[1] += msg[1];
|
||||
poseidon_permute_comp_bls(&mut state);
|
||||
}
|
||||
let r = iter.remainder();
|
||||
if r.len() == 1 {
|
||||
state[0] += r[0];
|
||||
}
|
||||
state[r.len()] += BlsFr::ONE;
|
||||
poseidon_permute_comp_bls(&mut state);
|
||||
|
||||
state[0]
|
||||
}
|
||||
|
||||
impl PoseidonBLSChannel {
|
||||
pub fn digest(&self) -> BlsFr {
|
||||
self.digest
|
||||
}
|
||||
pub fn update_digest(&mut self, new_digest: BlsFr) {
|
||||
self.digest = new_digest;
|
||||
self.channel_time.inc_challenges();
|
||||
}
|
||||
fn draw_felt252(&mut self) -> BlsFr {
|
||||
let res = poseidon_hash_bls(self.digest, BlsFr::from(self.channel_time.n_sent as u64));
|
||||
self.channel_time.inc_sent();
|
||||
res
|
||||
}
|
||||
|
||||
// TODO(spapini): Understand if we really need uniformity here.
|
||||
/// Generates a close-to-uniform random vector of BaseField elements.
|
||||
fn draw_base_felts(&mut self) -> [BaseField; 8] {
|
||||
let shift = NonZero::new(U256::from_u64(1u64 << 31)).unwrap();
|
||||
|
||||
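// Split one field-element draw into eight base-2^31 limbs by repeated division; each
// remainder (< 2^31) becomes a u32.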
let mut cur = self.draw_felt252();
|
||||
let u32s: [u32; 8usize] = std::array::from_fn(|_| {
|
||||
let (quotient, reminder) =
|
||||
U256::from_be_slice(&cur.into_bigint().to_bytes_be()).div_rem(&shift);
|
||||
cur = BlsFr::from_be_bytes_mod_order(&quotient.to_be_bytes());
|
||||
u32::from_str_radix(&reminder.to_string(),16).unwrap()
|
||||
});
|
||||
|
||||
u32s.into_iter()
|
||||
.map(|x| BaseField::reduce(x as u64))
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl Channel for PoseidonBLSChannel {
|
||||
const BYTES_PER_HASH: usize = BYTES_PER_FELT252;
|
||||
|
||||
fn trailing_zeros(&self) -> u32 {
|
||||
let bytes = self.digest.into_bigint().to_bytes_be();
|
||||
u128::from_le_bytes(std::array::from_fn(|i| bytes[i])).trailing_zeros()
|
||||
}
|
||||
|
||||
// TODO(spapini): Optimize.
|
||||
fn mix_felts(&mut self, felts: &[SecureField]) {
|
||||
let shift = BlsFr::from(1u64 << 31);
|
||||
let mut res = Vec::with_capacity(felts.len() / 2 + 2);
|
||||
res.push(self.digest);
|
||||
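// Pack each pair of secure field elements (eight M31 words) into one BLS scalar, base 2^31.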
for chunk in felts.chunks(2) {
|
||||
res.push(
|
||||
chunk
|
||||
.iter()
|
||||
.flat_map(|x| x.to_m31_array())
|
||||
.fold(BlsFr::default(), |cur, y| {
|
||||
cur * shift + BlsFr::from_be_bytes_mod_order(&y.0.to_be_bytes())
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
// TODO(spapini): do we need length padding?
|
||||
self.update_digest(poseidon_hash_many_bls(&res));
|
||||
}
|
||||
|
||||
fn mix_nonce(&mut self, nonce: u64) {
|
||||
self.update_digest(poseidon_hash_bls(self.digest, nonce.into()));
|
||||
}
|
||||
|
||||
fn draw_felt(&mut self) -> SecureField {
|
||||
let felts: [BaseField; FELTS_PER_HASH] = self.draw_base_felts();
|
||||
SecureField::from_m31_array(felts[..SECURE_EXTENSION_DEGREE].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn draw_felts(&mut self, n_felts: usize) -> Vec<SecureField> {
|
||||
let mut felts = iter::from_fn(|| Some(self.draw_base_felts())).flatten();
|
||||
let secure_felts = iter::from_fn(|| {
|
||||
Some(SecureField::from_m31_array([
|
||||
felts.next()?,
|
||||
felts.next()?,
|
||||
felts.next()?,
|
||||
felts.next()?,
|
||||
]))
|
||||
});
|
||||
secure_felts.take(n_felts).collect()
|
||||
}
|
||||
|
||||
fn draw_random_bytes(&mut self) -> Vec<u8> {
|
||||
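// Split one field-element draw into 31 bytes (least-significant first) by repeated division by 256.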
let shift = NonZero::new(U256::from_u64(1u64 << 8)).unwrap();
|
||||
let mut cur = self.draw_felt252();
|
||||
let bytes: [u8; 31] = std::array::from_fn(|_| {
|
||||
let (quotient, reminder) =
|
||||
U256::from_be_slice(&cur.into_bigint().to_bytes_be()).div_rem(&shift);
|
||||
cur = BlsFr::from_be_bytes_mod_order(&quotient.to_be_bytes());
|
||||
u8::from_str_radix(&reminder.to_string(),16).unwrap()
|
||||
});
|
||||
bytes.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::BTreeSet;
|
||||
|
||||
use crate::core::channel::poseidon_bls::PoseidonBLSChannel;
|
||||
use crate::core::channel::Channel;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::m31;
|
||||
|
||||
#[test]
|
||||
fn test_channel_time() {
|
||||
let mut channel = PoseidonBLSChannel::default();
|
||||
|
||||
assert_eq!(channel.channel_time.n_challenges, 0);
|
||||
assert_eq!(channel.channel_time.n_sent, 0);
|
||||
|
||||
channel.draw_random_bytes();
|
||||
assert_eq!(channel.channel_time.n_challenges, 0);
|
||||
assert_eq!(channel.channel_time.n_sent, 1);
|
||||
|
||||
channel.draw_felts(9);
|
||||
assert_eq!(channel.channel_time.n_challenges, 0);
|
||||
assert_eq!(channel.channel_time.n_sent, 6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_draw_random_bytes() {
|
||||
let mut channel = PoseidonBLSChannel::default();
|
||||
|
||||
let first_random_bytes = channel.draw_random_bytes();
|
||||
|
||||
// Assert that next random bytes are different.
|
||||
assert_ne!(first_random_bytes, channel.draw_random_bytes());
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_draw_felt() {
|
||||
let mut channel = PoseidonBLSChannel::default();
|
||||
|
||||
let first_random_felt = channel.draw_felt();
|
||||
|
||||
// Assert that next random felt is different.
|
||||
assert_ne!(first_random_felt, channel.draw_felt());
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_draw_felts() {
|
||||
let mut channel = PoseidonBLSChannel::default();
|
||||
|
||||
let mut random_felts = channel.draw_felts(5);
|
||||
random_felts.extend(channel.draw_felts(4));
|
||||
|
||||
// Assert that all the random felts are unique.
|
||||
assert_eq!(
|
||||
random_felts.len(),
|
||||
random_felts.iter().collect::<BTreeSet<_>>().len()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_mix_felts() {
|
||||
let mut channel = PoseidonBLSChannel::default();
|
||||
let initial_digest = channel.digest;
|
||||
let felts: Vec<SecureField> = (0..2)
|
||||
.map(|i| SecureField::from(m31!(i + 1923782)))
|
||||
.collect();
|
||||
|
||||
channel.mix_felts(felts.as_slice());
|
||||
|
||||
assert_ne!(initial_digest, channel.digest);
|
||||
}
|
||||
}
|
||||
561
Stwo_wrapper/crates/prover/src/core/circle.rs
Normal file
@ -0,0 +1,561 @@
|
||||
use std::ops::{Add, Div, Mul, Neg, Sub};
|
||||
|
||||
use num_traits::{One, Zero};
|
||||
|
||||
use super::fields::m31::{BaseField, M31};
|
||||
use super::fields::qm31::SecureField;
|
||||
use super::fields::{ComplexConjugate, Field, FieldExpOps};
|
||||
use crate::core::channel::Channel;
|
||||
use crate::core::fields::qm31::P4;
|
||||
use crate::math::utils::egcd;
|
||||
|
||||
/// A point on the complex circle. Treated as an additive group.
|
||||
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub struct CirclePoint<F> {
|
||||
pub x: F,
|
||||
pub y: F,
|
||||
}
|
||||
|
||||
impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>> CirclePoint<F> {
|
||||
pub fn zero() -> Self {
|
||||
Self {
|
||||
x: F::one(),
|
||||
y: F::zero(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn double(&self) -> Self {
|
||||
*self + *self
|
||||
}
|
||||
|
||||
/// Applies the circle's x-coordinate doubling map.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use stwo_prover::core::circle::{CirclePoint, M31_CIRCLE_GEN};
|
||||
/// use stwo_prover::core::fields::m31::M31;
|
||||
/// let p = M31_CIRCLE_GEN.mul(17);
|
||||
/// assert_eq!(CirclePoint::double_x(p.x), (p + p).x);
|
||||
/// ```
|
||||
pub fn double_x(x: F) -> F {
|
||||
let sx = x.square();
|
||||
sx + sx - F::one()
|
||||
}
|
||||
|
||||
/// Returns the log order of a point.
|
||||
///
|
||||
/// All points have an order of the form `2^k`.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use stwo_prover::core::circle::{CirclePoint, M31_CIRCLE_GEN, M31_CIRCLE_LOG_ORDER};
|
||||
/// use stwo_prover::core::fields::m31::M31;
|
||||
/// assert_eq!(M31_CIRCLE_GEN.log_order(), M31_CIRCLE_LOG_ORDER);
|
||||
/// ```
|
||||
pub fn log_order(&self) -> u32
|
||||
where
|
||||
F: PartialEq + Eq,
|
||||
{
|
||||
// we only need the x-coordinate to check order since the only point
|
||||
// with x=1 is the circle's identity
|
||||
let mut res = 0;
|
||||
let mut cur = self.x;
|
||||
while cur != F::one() {
|
||||
cur = Self::double_x(cur);
|
||||
res += 1;
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
pub fn mul(&self, mut scalar: u128) -> CirclePoint<F> {
|
||||
let mut res = Self::zero();
|
||||
let mut cur = *self;
|
||||
while scalar > 0 {
|
||||
if scalar & 1 == 1 {
|
||||
res = res + cur;
|
||||
}
|
||||
cur = cur.double();
|
||||
scalar >>= 1;
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
pub fn repeated_double(&self, n: u32) -> Self {
|
||||
let mut res = *self;
|
||||
for _ in 0..n {
|
||||
res = res.double();
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
pub fn conjugate(&self) -> CirclePoint<F> {
|
||||
Self {
|
||||
x: self.x,
|
||||
y: -self.y,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn antipode(&self) -> CirclePoint<F> {
|
||||
Self {
|
||||
x: -self.x,
|
||||
y: -self.y,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_ef<EF: From<F>>(&self) -> CirclePoint<EF> {
|
||||
CirclePoint {
|
||||
x: self.x.into(),
|
||||
y: self.y.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mul_signed(&self, off: isize) -> CirclePoint<F> {
|
||||
if off > 0 {
|
||||
self.mul(off as u128)
|
||||
} else {
|
||||
self.conjugate().mul(-off as u128)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>> Add
|
||||
for CirclePoint<F>
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
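// Group law: treating (x, y) as x + iy on the unit circle, point addition is complex multiplication.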
fn add(self, rhs: Self) -> Self::Output {
|
||||
let x = self.x * rhs.x - self.y * rhs.y;
|
||||
let y = self.x * rhs.y + self.y * rhs.x;
|
||||
Self { x, y }
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>> Neg
|
||||
for CirclePoint<F>
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
self.conjugate()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>> Sub
|
||||
for CirclePoint<F>
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
self + (-rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> ComplexConjugate for CirclePoint<F> {
|
||||
fn complex_conjugate(&self) -> Self {
|
||||
Self {
|
||||
x: self.x.complex_conjugate(),
|
||||
y: self.y.complex_conjugate(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CirclePoint<SecureField> {
|
||||
pub fn get_point(index: u128) -> Self {
|
||||
assert!(index < SECURE_FIELD_CIRCLE_ORDER);
|
||||
SECURE_FIELD_CIRCLE_GEN.mul(index)
|
||||
}
|
||||
|
||||
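/// Draws a random circle point via the rational parametrization
/// x = (1 - t^2) / (1 + t^2), y = 2t / (1 + t^2), where t is drawn from the channel.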
pub fn get_random_point<C: Channel>(channel: &mut C) -> Self {
|
||||
let t = channel.draw_felt();
|
||||
let t_square = t.square();
|
||||
|
||||
let one_plus_tsquared_inv = t_square.add(SecureField::one()).inverse();
|
||||
|
||||
let x = SecureField::one()
|
||||
.add(t_square.neg())
|
||||
.mul(one_plus_tsquared_inv);
|
||||
let y = t.double().mul(one_plus_tsquared_inv);
|
||||
|
||||
Self { x, y }
|
||||
}
|
||||
}
|
||||
|
||||
/// A generator for the circle group over [M31].
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use stwo_prover::core::circle::{CirclePoint, M31_CIRCLE_GEN};
|
||||
/// use stwo_prover::core::fields::m31::M31;
|
||||
///
|
||||
/// // Adding a generator to itself (2^30) times should NOT yield the identity.
|
||||
/// let circle_point = M31_CIRCLE_GEN.repeated_double(30);
|
||||
/// assert_ne!(circle_point, CirclePoint::zero());
|
||||
///
|
||||
/// // As shown above, ord(M31_CIRCLE_GEN) > 2^30. The group order is 2^31,
|
||||
/// // and ord(M31_CIRCLE_GEN) must divide it, hence ord(M31_CIRCLE_GEN) = 2^31.
|
||||
/// // Adding the generator to itself (2^31) times should yield the identity.
|
||||
/// let circle_point = M31_CIRCLE_GEN.repeated_double(31);
|
||||
/// assert_eq!(circle_point, CirclePoint::zero());
|
||||
/// ```
|
||||
pub const M31_CIRCLE_GEN: CirclePoint<M31> = CirclePoint {
|
||||
x: M31::from_u32_unchecked(2),
|
||||
y: M31::from_u32_unchecked(1268011823),
|
||||
};
|
||||
|
||||
/// Order of [M31_CIRCLE_GEN].
|
||||
pub const M31_CIRCLE_LOG_ORDER: u32 = 31;
|
||||
|
||||
/// A generator for the circle group over [SecureField].
|
||||
pub const SECURE_FIELD_CIRCLE_GEN: CirclePoint<SecureField> = CirclePoint {
|
||||
x: SecureField::from_u32_unchecked(1, 0, 478637715, 513582971),
|
||||
y: SecureField::from_u32_unchecked(992285211, 649143431, 740191619, 1186584352),
|
||||
};
|
||||
|
||||
/// Order of [SECURE_FIELD_CIRCLE_GEN].
|
||||
pub const SECURE_FIELD_CIRCLE_ORDER: u128 = P4 - 1;
|
||||
|
||||
/// Integer i that represents the circle point i * CIRCLE_GEN. Treated as an
|
||||
/// additive ring modulo `1 << M31_CIRCLE_LOG_ORDER`.
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Ord, PartialOrd)]
|
||||
pub struct CirclePointIndex(pub usize);
|
||||
|
||||
impl CirclePointIndex {
|
||||
pub fn zero() -> Self {
|
||||
Self(0)
|
||||
}
|
||||
|
||||
pub fn generator() -> Self {
|
||||
Self(1)
|
||||
}
|
||||
|
||||
pub fn reduce(self) -> Self {
|
||||
Self(self.0 & ((1 << M31_CIRCLE_LOG_ORDER) - 1))
|
||||
}
|
||||
|
||||
pub fn subgroup_gen(log_size: u32) -> Self {
|
||||
assert!(log_size <= M31_CIRCLE_LOG_ORDER);
|
||||
Self(1 << (M31_CIRCLE_LOG_ORDER - log_size))
|
||||
}
|
||||
|
||||
pub fn to_point(self) -> CirclePoint<M31> {
|
||||
M31_CIRCLE_GEN.mul(self.0 as u128)
|
||||
}
|
||||
|
||||
pub fn half(self) -> Self {
|
||||
assert!(self.0 & 1 == 0);
|
||||
Self(self.0 >> 1)
|
||||
}
|
||||
|
||||
pub fn try_div(&self, rhs: CirclePointIndex) -> Option<isize> {
|
||||
// Find x s.t. x * rhs.0 = self.0 (mod CIRCLE_ORDER).
|
||||
let (s, _t, g) = egcd(rhs.0 as isize, 1 << M31_CIRCLE_LOG_ORDER);
|
||||
if self.0 as isize % g != 0 {
|
||||
return None;
|
||||
}
|
||||
let res = s * self.0 as isize / g;
|
||||
Some(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl Add for CirclePointIndex {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
Self(self.0 + rhs.0).reduce()
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub for CirclePointIndex {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
Self(self.0 + (1 << M31_CIRCLE_LOG_ORDER) - rhs.0).reduce()
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<usize> for CirclePointIndex {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, rhs: usize) -> Self::Output {
|
||||
Self(self.0.wrapping_mul(rhs)).reduce()
|
||||
}
|
||||
}
|
||||
|
||||
impl Div for CirclePointIndex {
|
||||
type Output = isize;
|
||||
|
||||
fn div(self, rhs: Self) -> Self::Output {
|
||||
self.try_div(rhs).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl Neg for CirclePointIndex {
|
||||
type Output = Self;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
Self((1 << M31_CIRCLE_LOG_ORDER) - self.0).reduce()
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the coset initial + \<step\>.
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub struct Coset {
|
||||
pub initial_index: CirclePointIndex,
|
||||
pub initial: CirclePoint<M31>,
|
||||
pub step_size: CirclePointIndex,
|
||||
pub step: CirclePoint<M31>,
|
||||
pub log_size: u32,
|
||||
}
|
||||
|
||||
impl Coset {
|
||||
pub fn new(initial_index: CirclePointIndex, log_size: u32) -> Self {
|
||||
assert!(log_size <= M31_CIRCLE_LOG_ORDER);
|
||||
let step_size = CirclePointIndex::subgroup_gen(log_size);
|
||||
Self {
|
||||
initial_index,
|
||||
initial: initial_index.to_point(),
|
||||
step: step_size.to_point(),
|
||||
step_size,
|
||||
log_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a coset of the form <G_n>.
|
||||
/// For example, for n=8, we get the point indices \[0,1,2,3,4,5,6,7\].
|
||||
pub fn subgroup(log_size: u32) -> Self {
|
||||
Self::new(CirclePointIndex::zero(), log_size)
|
||||
}
|
||||
|
||||
/// Creates a coset of the form G_2n + \<G_n\>.
|
||||
/// For example, for n=8, we get the point indices \[1,3,5,7,9,11,13,15\].
|
||||
pub fn odds(log_size: u32) -> Self {
|
||||
Self::new(CirclePointIndex::subgroup_gen(log_size + 1), log_size)
|
||||
}
|
||||
|
||||
/// Creates a coset of the form G_4n + <G_n>.
|
||||
/// For example, for n=8, we get the point indices \[1,5,9,13,17,21,25,29\].
|
||||
/// Its conjugate will be \[3,7,11,15,19,23,27,31\].
|
||||
pub fn half_odds(log_size: u32) -> Self {
|
||||
Self::new(CirclePointIndex::subgroup_gen(log_size + 2), log_size)
|
||||
}
|
||||
|
||||
/// Returns the size of the coset.
|
||||
pub fn size(&self) -> usize {
|
||||
1 << self.log_size()
|
||||
}
|
||||
|
||||
/// Returns the log size of the coset.
|
||||
pub fn log_size(&self) -> u32 {
|
||||
self.log_size
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> CosetIterator<CirclePoint<M31>> {
|
||||
CosetIterator {
|
||||
cur: self.initial,
|
||||
step: self.step,
|
||||
remaining: self.size(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn iter_indices(&self) -> CosetIterator<CirclePointIndex> {
|
||||
CosetIterator {
|
||||
cur: self.initial_index,
|
||||
step: self.step_size,
|
||||
remaining: self.size(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new coset consisting of all points in the current coset, doubled.
|
||||
pub fn double(&self) -> Self {
|
||||
assert!(self.log_size > 0);
|
||||
Self {
|
||||
initial_index: self.initial_index * 2,
|
||||
initial: self.initial.double(),
|
||||
step: self.step.double(),
|
||||
step_size: self.step_size * 2,
|
||||
log_size: self.log_size.saturating_sub(1),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn repeated_double(&self, n_doubles: u32) -> Self {
|
||||
(0..n_doubles).fold(*self, |coset, _| coset.double())
|
||||
}
|
||||
|
||||
pub fn is_doubling_of(&self, other: Self) -> bool {
|
||||
self.log_size <= other.log_size
|
||||
&& *self == other.repeated_double(other.log_size - self.log_size)
|
||||
}
|
||||
|
||||
pub fn initial(&self) -> CirclePoint<M31> {
|
||||
self.initial
|
||||
}
|
||||
|
||||
pub fn index_at(&self, index: usize) -> CirclePointIndex {
|
||||
self.initial_index + self.step_size.mul(index)
|
||||
}
|
||||
|
||||
pub fn at(&self, index: usize) -> CirclePoint<M31> {
|
||||
self.index_at(index).to_point()
|
||||
}
|
||||
|
||||
pub fn shift(&self, shift_size: CirclePointIndex) -> Self {
|
||||
let initial_index = self.initial_index + shift_size;
|
||||
Self {
|
||||
initial_index,
|
||||
initial: initial_index.to_point(),
|
||||
..*self
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates the conjugate coset: -initial -\<step\>.
|
||||
pub fn conjugate(&self) -> Self {
|
||||
let initial_index = -self.initial_index;
|
||||
let step_size = -self.step_size;
|
||||
Self {
|
||||
initial_index,
|
||||
initial: initial_index.to_point(),
|
||||
step_size,
|
||||
step: step_size.to_point(),
|
||||
log_size: self.log_size,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn find(&self, i: CirclePointIndex) -> Option<usize> {
|
||||
let res = (i - self.initial_index).try_div(self.step_size)?;
|
||||
Some(res.rem_euclid(self.size() as isize) as usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoIterator for Coset {
|
||||
type Item = CirclePoint<BaseField>;
|
||||
type IntoIter = CosetIterator<CirclePoint<BaseField>>;
|
||||
|
||||
/// Iterates over the points in the coset.
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.iter()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CosetIterator<T: Add> {
|
||||
pub cur: T,
|
||||
pub step: T,
|
||||
pub remaining: usize,
|
||||
}
|
||||
|
||||
impl<T: Add<Output = T> + Copy> Iterator for CosetIterator<T> {
|
||||
type Item = T;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.remaining == 0 {
|
||||
return None;
|
||||
}
|
||||
self.remaining -= 1;
|
||||
let res = self.cur;
|
||||
self.cur = self.cur + self.step;
|
||||
Some(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::BTreeSet;
|
||||
|
||||
use num_traits::{One, Pow};
|
||||
|
||||
use super::{CirclePointIndex, Coset};
|
||||
use crate::core::channel::Blake2sChannel;
|
||||
use crate::core::circle::{CirclePoint, SECURE_FIELD_CIRCLE_GEN};
|
||||
use crate::core::fields::qm31::{SecureField, P4};
|
||||
use crate::core::fields::FieldExpOps;
|
||||
use crate::core::poly::circle::CanonicCoset;
|
||||
|
||||
#[test]
|
||||
fn test_iterator() {
|
||||
let coset = Coset::new(CirclePointIndex(1), 3);
|
||||
let actual_indices: Vec<_> = coset.iter_indices().collect();
|
||||
let expected_indices = vec![
|
||||
CirclePointIndex(1),
|
||||
CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 1,
|
||||
CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 2,
|
||||
CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 3,
|
||||
CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 4,
|
||||
CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 5,
|
||||
CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 6,
|
||||
CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 7,
|
||||
];
|
||||
assert_eq!(actual_indices, expected_indices);
|
||||
|
||||
let actual_points = coset.iter().collect::<Vec<_>>();
|
||||
let expected_points: Vec<_> = expected_indices.iter().map(|i| i.to_point()).collect();
|
||||
assert_eq!(actual_points, expected_points);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coset_is_half_coset_with_conjugate() {
|
||||
let canonic_coset = CanonicCoset::new(8);
|
||||
let coset_points = BTreeSet::from_iter(canonic_coset.coset().iter());
|
||||
|
||||
let half_coset_points = BTreeSet::from_iter(canonic_coset.half_coset().iter());
|
||||
let half_coset_conjugate_points =
|
||||
BTreeSet::from_iter(canonic_coset.half_coset().conjugate().iter());
|
||||
|
||||
assert!((&half_coset_points & &half_coset_conjugate_points).is_empty());
|
||||
assert_eq!(
|
||||
coset_points,
|
||||
&half_coset_points | &half_coset_conjugate_points
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_get_random_circle_point() {
|
||||
let mut channel = Blake2sChannel::default();
|
||||
|
||||
let first_random_circle_point = CirclePoint::get_random_point(&mut channel);
|
||||
|
||||
// Assert that the next random circle point is different.
|
||||
assert_ne!(
|
||||
first_random_circle_point,
|
||||
CirclePoint::get_random_point(&mut channel)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_secure_field_circle_gen() {
|
||||
let prime_factors = [
|
||||
(2, 33),
|
||||
(3, 2),
|
||||
(5, 1),
|
||||
(7, 1),
|
||||
(11, 1),
|
||||
(31, 1),
|
||||
(151, 1),
|
||||
(331, 1),
|
||||
(733, 1),
|
||||
(1709, 1),
|
||||
(368140581013, 1),
|
||||
];
|
||||
|
||||
assert_eq!(
|
||||
prime_factors
|
||||
.iter()
|
||||
.map(|(p, e)| p.pow(*e as u32))
|
||||
.product::<u128>(),
|
||||
P4 - 1
|
||||
);
|
||||
assert_eq!(
|
||||
SECURE_FIELD_CIRCLE_GEN.x.square() + SECURE_FIELD_CIRCLE_GEN.y.square(),
|
||||
SecureField::one()
|
||||
);
|
||||
assert_eq!(SECURE_FIELD_CIRCLE_GEN.mul(P4 - 1), CirclePoint::zero());
|
||||
for (p, _) in prime_factors.iter() {
|
||||
assert_ne!(
|
||||
SECURE_FIELD_CIRCLE_GEN.mul((P4 - 1) / *p),
|
||||
CirclePoint::zero()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
251
Stwo_wrapper/crates/prover/src/core/constraints.rs
Normal file
@ -0,0 +1,251 @@
|
||||
use num_traits::One;
|
||||
|
||||
use super::circle::{CirclePoint, Coset};
|
||||
use super::fields::m31::BaseField;
|
||||
use super::fields::qm31::SecureField;
|
||||
use super::fields::ExtensionOf;
|
||||
use super::pcs::quotients::PointSample;
|
||||
use crate::core::fields::ComplexConjugate;
|
||||
|
||||
/// Evaluates a vanishing polynomial of the coset at a point.
|
||||
pub fn coset_vanishing<F: ExtensionOf<BaseField>>(coset: Coset, mut p: CirclePoint<F>) -> F {
|
||||
// Doubling a point `log_order - 1` times and taking the x coordinate is
|
||||
// essentially evaluating a polynomial in x of degree `2^(log_order - 1)`. If
|
||||
// the entire `2^log_order` points of the coset are roots (i.e. yield 0), then
|
||||
// this is a vanishing polynomial of these points.
|
||||
|
||||
// Rotating the coset -coset.initial + step / 2 yields a canonic coset:
|
||||
// `step/2 + <step>.`
|
||||
// Doubling this coset log_order - 1 times yields the coset +-G_4.
|
||||
// The polynomial x vanishes on these points.
|
||||
// ```text
|
||||
// X
|
||||
// . .
|
||||
// X
|
||||
// ```
|
||||
p = p - coset.initial.into_ef() + coset.step_size.half().to_point().into_ef();
|
||||
//println!("p - coset.initial.into_ef() + coset.step_size.half().to_point().into_ef() = {:?} - {:?} + {:?}",p,coset.initial,coset.step_size.half().0);
|
||||
let mut x = p.x;
|
||||
|
||||
// The formula for the x coordinate of the double of a point.
|
||||
for _ in 1..coset.log_size {
|
||||
x = CirclePoint::double_x(x);
|
||||
}
|
||||
x
|
||||
}
|
||||
|
||||
/// Evaluates the polynomial that is used to exclude the excluded point at point
|
||||
/// p. Note that this polynomial has a zero of multiplicity 2 at the excluded
|
||||
/// point.
|
||||
pub fn point_excluder<F: ExtensionOf<BaseField>>(
|
||||
excluded: CirclePoint<BaseField>,
|
||||
p: CirclePoint<F>,
|
||||
) -> F {
|
||||
(p - excluded.into_ef()).x - BaseField::one()
|
||||
}
|
||||
|
||||
// A vanishing polynomial on 2 circle points.
|
||||
pub fn pair_vanishing<F: ExtensionOf<BaseField>>(
|
||||
excluded0: CirclePoint<F>,
|
||||
excluded1: CirclePoint<F>,
|
||||
p: CirclePoint<F>,
|
||||
) -> F {
|
||||
// The algorithm computes the area of the triangle formed by the
|
||||
// 3 points. This is done using the determinant of:
|
||||
// | p.x p.y 1 |
|
||||
// | e0.x e0.y 1 |
|
||||
// | e1.x e1.y 1 |
|
||||
// This is a polynomial of degree 1 in p.x and p.y, and thus it is a line.
|
||||
// It vanishes at e0 and e1.
|
||||
(excluded0.y - excluded1.y) * p.x
|
||||
+ (excluded1.x - excluded0.x) * p.y
|
||||
+ (excluded0.x * excluded1.y - excluded0.y * excluded1.x)
|
||||
}
|
||||
|
||||
/// Evaluates a vanishing polynomial of the vanish_point at a point.
|
||||
/// Note that this function has a pole on the antipode of the vanish_point.
|
||||
pub fn point_vanishing<F: ExtensionOf<BaseField>, EF: ExtensionOf<F>>(
|
||||
vanish_point: CirclePoint<F>,
|
||||
p: CirclePoint<EF>,
|
||||
) -> EF {
|
||||
let h = p - vanish_point.into_ef();
|
||||
h.y / (EF::one() + h.x)
|
||||
}
|
||||
|
||||
/// Evaluates a point on a line between a point and its complex conjugate.
|
||||
/// Relies on the fact that every polynomial F over the base field holds:
|
||||
/// F(p*) == F(p)* (* being the complex conjugate).
|
||||
pub fn complex_conjugate_line(
|
||||
point: CirclePoint<SecureField>,
|
||||
value: SecureField,
|
||||
p: CirclePoint<BaseField>,
|
||||
) -> SecureField {
|
||||
// TODO(AlonH): This assertion fails with probability about 1 in 2^62. Use a better solution.
|
||||
assert_ne!(
|
||||
point.y,
|
||||
point.y.complex_conjugate(),
|
||||
"Cannot evaluate a line with a single point ({point:?})."
|
||||
);
|
||||
value
|
||||
+ (value.complex_conjugate() - value) * (-point.y + p.y)
|
||||
/ (point.complex_conjugate().y - point.y)
|
||||
}
|
||||
|
||||
/// Evaluates the coefficients of a line between a point and its complex conjugate. Specifically,
|
||||
/// `a, b, and c, s.t. a*x + b -c*y = 0` for (x,y) being (sample.y, sample.value) and
|
||||
/// (conj(sample.y), conj(sample.value)).
|
||||
/// Relies on the fact that every polynomial F over the base
|
||||
/// field holds: F(p*) == F(p)* (* being the complex conjugate).
|
||||
pub fn complex_conjugate_line_coeffs(
|
||||
sample: &PointSample,
|
||||
alpha: SecureField,
|
||||
) -> (SecureField, SecureField, SecureField) {
|
||||
// TODO(AlonH): This assertion fails with probability about 1 in 2^62. Use a better solution.
|
||||
assert_ne!(
|
||||
sample.point.y,
|
||||
sample.point.y.complex_conjugate(),
|
||||
"Cannot evaluate a line with a single point ({:?}).",
|
||||
sample.point
|
||||
);
|
||||
let a = sample.value.complex_conjugate() - sample.value;
|
||||
let c = sample.point.complex_conjugate().y - sample.point.y;
|
||||
let b = sample.value * c - a * sample.point.y;
|
||||
(alpha * a, alpha * b, alpha * c)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use num_traits::Zero;
|
||||
|
||||
use super::{coset_vanishing, point_excluder, point_vanishing};
|
||||
use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly};
|
||||
use crate::core::circle::{CirclePoint, CirclePointIndex, Coset};
|
||||
use crate::core::constraints::{complex_conjugate_line, pair_vanishing};
|
||||
use crate::core::fields::m31::{BaseField, M31};
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::{ComplexConjugate, FieldExpOps};
|
||||
use crate::core::poly::circle::CanonicCoset;
|
||||
use crate::core::poly::NaturalOrder;
|
||||
use crate::core::test_utils::secure_eval_to_base_eval;
|
||||
use crate::m31;
|
||||
|
||||
#[test]
|
||||
fn test_coset_vanishing() {
|
||||
let cosets = [
|
||||
Coset::half_odds(5),
|
||||
Coset::odds(5),
|
||||
Coset::new(CirclePointIndex::zero(), 5),
|
||||
Coset::half_odds(5).conjugate(),
|
||||
];
|
||||
for c0 in cosets.iter() {
|
||||
for el in c0.iter() {
|
||||
assert_eq!(coset_vanishing(*c0, el), BaseField::zero());
|
||||
for c1 in cosets.iter() {
|
||||
if c0 == c1 {
|
||||
continue;
|
||||
}
|
||||
assert_ne!(coset_vanishing(*c1, el), BaseField::zero());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_point_excluder() {
|
||||
let excluded = Coset::half_odds(5).at(10);
|
||||
let point = (CirclePointIndex::generator() * 4).to_point();
|
||||
|
||||
let num = point_excluder(excluded, point) * point_excluder(excluded.conjugate(), point);
|
||||
let denom = (point.x - excluded.x).pow(2);
|
||||
|
||||
assert_eq!(num, denom);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pair_excluder() {
|
||||
let excluded0 = Coset::half_odds(5).at(10);
|
||||
let excluded1 = Coset::half_odds(5).at(13);
|
||||
let point = (CirclePointIndex::generator() * 4).to_point();
|
||||
|
||||
assert_ne!(pair_vanishing(excluded0, excluded1, point), M31::zero());
|
||||
assert_eq!(pair_vanishing(excluded0, excluded1, excluded0), M31::zero());
|
||||
assert_eq!(pair_vanishing(excluded0, excluded1, excluded1), M31::zero());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_point_vanishing_success() {
|
||||
let coset = Coset::odds(5);
|
||||
let vanish_point = coset.at(2);
|
||||
for el in coset.iter() {
|
||||
if el == vanish_point {
|
||||
assert_eq!(point_vanishing(vanish_point, el), BaseField::zero());
|
||||
continue;
|
||||
}
|
||||
if el == vanish_point.antipode() {
|
||||
continue;
|
||||
}
|
||||
assert_ne!(point_vanishing(vanish_point, el), BaseField::zero());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "0 has no inverse")]
|
||||
fn test_point_vanishing_failure() {
|
||||
let coset = Coset::half_odds(6);
|
||||
let point = coset.at(4);
|
||||
point_vanishing(point, point.antipode());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_complex_conjugate_symmetry() {
|
||||
// Create a polynomial over a base circle domain.
|
||||
let polynomial = CpuCirclePoly::new((0..1 << 7).map(|i| m31!(i)).collect());
|
||||
let oods_point = CirclePoint::get_point(9834759221);
|
||||
|
||||
// Assert that the base field polynomial is complex conjugate symmetric.
|
||||
assert_eq!(
|
||||
polynomial.eval_at_point(oods_point.complex_conjugate()),
|
||||
polynomial.eval_at_point(oods_point).complex_conjugate()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_point_vanishing_degree() {
|
||||
// Create a polynomial over a circle domain.
|
||||
let log_domain_size = 7;
|
||||
let domain_size = 1 << log_domain_size;
|
||||
let polynomial = CpuCirclePoly::new((0..domain_size).map(|i| m31!(i)).collect());
|
||||
|
||||
// Create a larger domain.
|
||||
let log_large_domain_size = log_domain_size + 1;
|
||||
let large_domain_size = 1 << log_large_domain_size;
|
||||
let large_domain = CanonicCoset::new(log_large_domain_size).circle_domain();
|
||||
|
||||
// Create a vanish point that is not in the large domain.
|
||||
let vanish_point = CirclePoint::get_point(97);
|
||||
let vanish_point_value = polynomial.eval_at_point(vanish_point);
|
||||
|
||||
// Compute the quotient polynomial.
|
||||
let mut quotient_polynomial_values = Vec::with_capacity(large_domain_size as usize);
|
||||
for point in large_domain.iter() {
|
||||
let line = complex_conjugate_line(vanish_point, vanish_point_value, point);
|
||||
let mut value = polynomial.eval_at_point(point.into_ef()) - line;
|
||||
value /= pair_vanishing(
|
||||
vanish_point,
|
||||
vanish_point.complex_conjugate(),
|
||||
point.into_ef(),
|
||||
);
|
||||
quotient_polynomial_values.push(value);
|
||||
}
|
||||
let quotient_evaluation = CpuCircleEvaluation::<SecureField, NaturalOrder>::new(
|
||||
large_domain,
|
||||
quotient_polynomial_values,
|
||||
);
|
||||
let quotient_polynomial = secure_eval_to_base_eval(&quotient_evaluation)
|
||||
.bit_reverse()
|
||||
.interpolate();
|
||||
|
||||
// Check that the quotient polynomial is indeed in the wanted fft space.
|
||||
assert!(quotient_polynomial.is_in_fft_space(log_domain_size));
|
||||
}
|
||||
}
|
||||
21
Stwo_wrapper/crates/prover/src/core/fft.rs
Normal file
@ -0,0 +1,21 @@
|
||||
use std::ops::{Add, AddAssign, Mul, Sub};
|
||||
|
||||
use super::fields::m31::BaseField;
|
||||
|
||||
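/// Forward FFT butterfly: (v0, v1) <- (v0 + twid * v1, v0 - twid * v1).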
pub fn butterfly<F>(v0: &mut F, v1: &mut F, twid: BaseField)
|
||||
where
|
||||
F: Copy + AddAssign<F> + Sub<F, Output = F> + Mul<BaseField, Output = F>,
|
||||
{
|
||||
let tmp = *v1 * twid;
|
||||
*v1 = *v0 - tmp;
|
||||
*v0 += tmp;
|
||||
}
|
||||
|
||||
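/// Inverse FFT butterfly: (v0, v1) <- (v0 + v1, (v0 - v1) * itwid).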
pub fn ibutterfly<F>(v0: &mut F, v1: &mut F, itwid: BaseField)
|
||||
where
|
||||
F: Copy + AddAssign<F> + Add<F, Output = F> + Sub<F, Output = F> + Mul<BaseField, Output = F>,
|
||||
{
|
||||
let tmp = *v0;
|
||||
*v0 = tmp + *v1;
|
||||
*v1 = (tmp - *v1) * itwid;
|
||||
}
|
||||
137
Stwo_wrapper/crates/prover/src/core/fields/cm31.rs
Normal file
@ -0,0 +1,137 @@
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::ops::{
|
||||
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{ComplexConjugate, FieldExpOps};
|
||||
use crate::core::fields::m31::M31;
|
||||
use crate::{impl_extension_field, impl_field};
|
||||
pub const P2: u64 = 4611686014132420609; // (2 ** 31 - 1) ** 2
|
||||
|
||||
/// Complex extension field of M31.
|
||||
/// Equivalent to M31\[x\] over (x^2 + 1) as the irreducible polynomial.
|
||||
/// Represented as (a, b) of a + bi.
|
||||
#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
|
||||
pub struct CM31(pub M31, pub M31);
|
||||
|
||||
impl_field!(CM31, P2);
|
||||
impl_extension_field!(CM31, M31);
|
||||
|
||||
impl CM31 {
|
||||
pub const fn from_u32_unchecked(a: u32, b: u32) -> CM31 {
|
||||
Self(M31::from_u32_unchecked(a), M31::from_u32_unchecked(b))
|
||||
}
|
||||
|
||||
pub fn from_m31(a: M31, b: M31) -> CM31 {
|
||||
Self(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for CM31 {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{} + {}i", self.0, self.1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for CM31 {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{} + {}i", self.0, self.1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul for CM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
// (a + bi) * (c + di) = (ac - bd) + (ad + bc)i.
|
||||
Self(
|
||||
self.0 * rhs.0 - self.1 * rhs.1,
|
||||
self.0 * rhs.1 + self.1 * rhs.0,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryInto<M31> for CM31 {
|
||||
type Error = ();
|
||||
|
||||
fn try_into(self) -> Result<M31, Self::Error> {
|
||||
if self.1 != M31::zero() {
|
||||
return Err(());
|
||||
}
|
||||
Ok(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl FieldExpOps for CM31 {
|
||||
fn inverse(&self) -> Self {
|
||||
assert!(!self.is_zero(), "0 has no inverse");
|
||||
// 1 / (a + bi) = (a - bi) / (a^2 + b^2).
|
||||
Self(self.0, -self.1) * (self.0.square() + self.1.square()).inverse()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[macro_export]
|
||||
macro_rules! cm31 {
|
||||
($m0:expr, $m1:expr) => {
|
||||
CM31::from_u32_unchecked($m0, $m1)
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use super::CM31;
|
||||
use crate::core::fields::m31::P;
|
||||
use crate::core::fields::{FieldExpOps, IntoSlice};
|
||||
use crate::m31;
|
||||
|
||||
#[test]
|
||||
fn test_inverse() {
|
||||
let cm = cm31!(1, 2);
|
||||
let cm_inv = cm.inverse();
|
||||
assert_eq!(cm * cm_inv, cm31!(1, 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ops() {
|
||||
let cm0 = cm31!(1, 2);
|
||||
let cm1 = cm31!(4, 5);
|
||||
let m = m31!(8);
|
||||
let cm = CM31::from(m);
|
||||
let cm0_x_cm1 = cm31!(P - 6, 13);
|
||||
|
||||
assert_eq!(cm0 + cm1, cm31!(5, 7));
|
||||
assert_eq!(cm1 + m, cm1 + cm);
|
||||
assert_eq!(cm0 * cm1, cm0_x_cm1);
|
||||
assert_eq!(cm1 * m, cm1 * cm);
|
||||
assert_eq!(-cm0, cm31!(P - 1, P - 2));
|
||||
assert_eq!(cm0 - cm1, cm31!(P - 3, P - 3));
|
||||
assert_eq!(cm1 - m, cm1 - cm);
|
||||
assert_eq!(cm0_x_cm1 / cm1, cm31!(1, 2));
|
||||
assert_eq!(cm1 / m, cm1 / cm);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_into_slice() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let x = (0..100).map(|_| rng.gen()).collect::<Vec<CM31>>();
|
||||
|
||||
let slice = CM31::into_slice(&x);
|
||||
|
||||
for i in 0..100 {
|
||||
let corresponding_sub_slice = &slice[i * 8..(i + 1) * 8];
|
||||
assert_eq!(
|
||||
x[i],
|
||||
cm31!(
|
||||
u32::from_le_bytes(corresponding_sub_slice[..4].try_into().unwrap()),
|
||||
u32::from_le_bytes(corresponding_sub_slice[4..].try_into().unwrap())
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
258
Stwo_wrapper/crates/prover/src/core/fields/m31.rs
Normal file
@ -0,0 +1,258 @@
|
||||
use std::fmt::Display;
|
||||
use std::ops::{
|
||||
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
|
||||
};
|
||||
|
||||
use bytemuck::{Pod, Zeroable};
|
||||
use rand::distributions::{Distribution, Standard};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{ComplexConjugate, FieldExpOps};
|
||||
use crate::impl_field;
|
||||
pub const MODULUS_BITS: u32 = 31;
|
||||
pub const N_BYTES_FELT: usize = 4;
|
||||
pub const P: u32 = 2147483647; // 2 ** 31 - 1
|
||||
|
||||
#[repr(transparent)]
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
Debug,
|
||||
Default,
|
||||
PartialEq,
|
||||
Eq,
|
||||
PartialOrd,
|
||||
Ord,
|
||||
Hash,
|
||||
Pod,
|
||||
Zeroable,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
)]
|
||||
pub struct M31(pub u32);
|
||||
pub type BaseField = M31;
|
||||
|
||||
impl_field!(M31, P);
|
||||
|
||||
impl M31 {
|
||||
/// Returns `val % P` when `val` is in the range `[0, 2P)`.
|
||||
///
|
||||
/// ```
|
||||
/// use stwo_prover::core::fields::m31::{M31, P};
|
||||
///
|
||||
/// let val = 2 * P - 19;
|
||||
/// assert_eq!(M31::partial_reduce(val), M31::from(P - 19));
|
||||
/// ```
|
||||
pub fn partial_reduce(val: u32) -> Self {
|
||||
Self(val.checked_sub(P).unwrap_or(val))
|
||||
}
|
||||
|
||||
/// Returns `val % P` when `val` is in the range `[0, P^2)`.
|
||||
///
|
||||
/// ```
|
||||
/// use stwo_prover::core::fields::m31::{M31, P};
|
||||
///
|
||||
/// let val = (P as u64).pow(2) - 19;
|
||||
/// assert_eq!(M31::reduce(val), M31::from(P - 19));
|
||||
/// ```
|
||||
pub fn reduce(val: u64) -> Self {
|
||||
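// Branch-free reduction for val < P^2: fold the high bits into the low bits twice, then mask by P.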
Self((((((val >> MODULUS_BITS) + val + 1) >> MODULUS_BITS) + val) & (P as u64)) as u32)
|
||||
}
|
||||
|
||||
pub const fn from_u32_unchecked(arg: u32) -> Self {
|
||||
Self(arg)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for M31 {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Add for M31 {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
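// With both operands already reduced to [0, P), the sum is below 2P (and fits in u32),
// so the single conditional subtraction in `partial_reduce` is enough.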
|
||||
Self::partial_reduce(self.0 + rhs.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Neg for M31 {
|
||||
type Output = Self;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
Self::partial_reduce(P - self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub for M31 {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
Self::partial_reduce(self.0 + P - rhs.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul for M31 {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, rhs: Self) -> Self::Output {
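// The 64-bit product of two elements in [0, P) is below P^2, the range `reduce` handles.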
|
||||
Self::reduce((self.0 as u64) * (rhs.0 as u64))
|
||||
}
|
||||
}
|
||||
|
||||
impl FieldExpOps for M31 {
|
||||
/// ```
|
||||
/// use num_traits::One;
|
||||
/// use stwo_prover::core::fields::m31::BaseField;
|
||||
/// use stwo_prover::core::fields::FieldExpOps;
|
||||
///
|
||||
/// let v = BaseField::from(19);
|
||||
/// assert_eq!(v.inverse() * v, BaseField::one());
|
||||
/// ```
|
||||
fn inverse(&self) -> Self {
|
||||
assert!(!self.is_zero(), "0 has no inverse");
|
||||
pow2147483645(*self)
|
||||
}
|
||||
}
|
||||
|
||||
impl ComplexConjugate for M31 {
|
||||
fn complex_conjugate(&self) -> Self {
|
||||
*self
|
||||
}
|
||||
}
|
||||
|
||||
impl One for M31 {
|
||||
fn one() -> Self {
|
||||
Self(1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Zero for M31 {
|
||||
fn zero() -> Self {
|
||||
Self(0)
|
||||
}
|
||||
|
||||
fn is_zero(&self) -> bool {
|
||||
*self == Self::zero()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<usize> for M31 {
|
||||
fn from(value: usize) -> Self {
|
||||
M31::reduce(value.try_into().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u32> for M31 {
|
||||
fn from(value: u32) -> Self {
|
||||
M31::reduce(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<i32> for M31 {
|
||||
fn from(value: i32) -> Self {
|
||||
M31::reduce(value.try_into().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl Distribution<M31> for Standard {
|
||||
// Not intended for cryptographic use. Should only be used in tests and benchmarks.
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> M31 {
|
||||
M31(rng.gen_range(0..P))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[macro_export]
|
||||
macro_rules! m31 {
|
||||
($m:expr) => {
|
||||
$crate::core::fields::m31::M31::from_u32_unchecked($m)
|
||||
};
|
||||
}
|
||||
|
||||
/// Computes `v^((2^31-1)-2)`.
|
||||
///
|
||||
/// Computes the multiplicative inverse of [`M31`] elements with 37 multiplications vs naive 60
|
||||
/// multiplications. Made generic to support both vectorized and non-vectorized implementations.
|
||||
/// Multiplication tree found with [addchain](https://github.com/mmcloughlin/addchain).
|
||||
///
|
||||
/// ```
|
||||
/// use stwo_prover::core::fields::m31::{pow2147483645, BaseField};
|
||||
/// use stwo_prover::core::fields::FieldExpOps;
|
||||
///
|
||||
/// let v = BaseField::from(19);
|
||||
/// assert_eq!(pow2147483645(v), v.pow(2147483645));
|
||||
/// ```
|
||||
pub fn pow2147483645<T: FieldExpOps>(v: T) -> T {
|
||||
let t0 = sqn::<2, T>(v) * v;
|
||||
let t1 = sqn::<1, T>(t0) * t0;
|
||||
let t2 = sqn::<3, T>(t1) * t0;
|
||||
let t3 = sqn::<1, T>(t2) * t0;
|
||||
let t4 = sqn::<8, T>(t3) * t3;
|
||||
let t5 = sqn::<8, T>(t4) * t3;
|
||||
sqn::<7, T>(t5) * t2
|
||||
}
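For reference, the exponent accounting of the chain above (with `sqn` squaring its input `N` times): t0 = v^5, t1 = v^15, t2 = v^125, t3 = v^255, t4 = v^(2^16 - 1), t5 = v^(2^24 - 1), and the final step yields v^((2^24 - 1) * 2^7 + 125) = v^(2^31 - 3) = v^(P - 2), the inverse exponent by Fermat's little theorem; 30 squarings plus 7 extra multiplications account for the 37 multiplications quoted above.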
|
||||
|
||||
/// Computes `v^(2^n)` by squaring `v` `n` times.
|
||||
fn sqn<const N: usize, T: FieldExpOps>(mut v: T) -> T {
|
||||
for _ in 0..N {
|
||||
v = v.square();
|
||||
}
|
||||
v
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use super::{M31, P};
|
||||
use crate::core::fields::IntoSlice;
|
||||
|
||||
fn mul_p(a: u32, b: u32) -> u32 {
|
||||
((a as u64 * b as u64) % P as u64) as u32
|
||||
}
|
||||
|
||||
fn add_p(a: u32, b: u32) -> u32 {
|
||||
(a + b) % P
|
||||
}
|
||||
|
||||
fn neg_p(a: u32) -> u32 {
|
||||
if a == 0 {
|
||||
0
|
||||
} else {
|
||||
P - a
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_basic_ops() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
for _ in 0..10000 {
|
||||
let x: u32 = rng.gen::<u32>() % P;
|
||||
let y: u32 = rng.gen::<u32>() % P;
|
||||
assert_eq!(m31!(add_p(x, y)), m31!(x) + m31!(y));
|
||||
assert_eq!(m31!(mul_p(x, y)), m31!(x) * m31!(y));
|
||||
assert_eq!(m31!(neg_p(x)), -m31!(x));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_into_slice() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let x = (0..100).map(|_| rng.gen()).collect::<Vec<M31>>();
|
||||
|
||||
let slice = M31::into_slice(&x);
|
||||
|
||||
for i in 0..100 {
|
||||
assert_eq!(
|
||||
x[i],
|
||||
m31!(u32::from_le_bytes(
|
||||
slice[i * 4..(i + 1) * 4].try_into().unwrap()
|
||||
))
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
489
Stwo_wrapper/crates/prover/src/core/fields/mod.rs
Normal file
@ -0,0 +1,489 @@
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::iter::{Product, Sum};
|
||||
use std::ops::{Mul, MulAssign, Neg};
|
||||
|
||||
use num_traits::{NumAssign, NumAssignOps, NumOps, One};
|
||||
|
||||
use super::backend::ColumnOps;
|
||||
|
||||
pub mod cm31;
|
||||
pub mod m31;
|
||||
pub mod qm31;
|
||||
pub mod secure_column;
|
||||
|
||||
pub trait FieldOps<F: Field>: ColumnOps<F> {
|
||||
// TODO(Ohad): change to use a mutable slice.
|
||||
fn batch_inverse(column: &Self::Column, dst: &mut Self::Column);
|
||||
}
|
||||
|
||||
pub trait FieldExpOps: Mul<Output = Self> + MulAssign + Sized + One + Copy {
|
||||
fn square(&self) -> Self {
|
||||
(*self) * (*self)
|
||||
}
|
||||
|
||||
fn pow(&self, exp: u128) -> Self {
|
||||
let mut res = Self::one();
|
||||
let mut base = *self;
|
||||
let mut exp = exp;
|
||||
while exp > 0 {
|
||||
if exp & 1 == 1 {
|
||||
res *= base;
|
||||
}
|
||||
base = base.square();
|
||||
exp >>= 1;
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn inverse(&self) -> Self;
|
||||
|
||||
/// Inverts a batch of elements using Montgomery's trick.
|
||||
fn batch_inverse(column: &[Self], dst: &mut [Self]) {
|
||||
const WIDTH: usize = 4;
|
||||
let n = column.len();
|
||||
debug_assert!(dst.len() >= n);
|
||||
|
||||
if n <= WIDTH || n % WIDTH != 0 {
|
||||
batch_inverse_classic(column, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
// First pass. Compute 'WIDTH' cumulative products in an interleaving fashion, reducing
|
||||
// instruction dependency and allowing better pipelining.
|
||||
let mut cum_prod: [Self; WIDTH] = [Self::one(); WIDTH];
|
||||
dst[..WIDTH].copy_from_slice(&cum_prod);
|
||||
for i in 0..n {
|
||||
cum_prod[i % WIDTH] *= column[i];
|
||||
dst[i] = cum_prod[i % WIDTH];
|
||||
}
|
||||
|
||||
// Inverse cumulative products.
|
||||
// Use classic batch inversion.
|
||||
let mut tail_inverses = [Self::one(); WIDTH];
|
||||
batch_inverse_classic(&dst[n - WIDTH..], &mut tail_inverses);
|
||||
|
||||
// Second pass.
|
||||
for i in (WIDTH..n).rev() {
|
||||
dst[i] = dst[i - WIDTH] * tail_inverses[i % WIDTH];
|
||||
tail_inverses[i % WIDTH] *= column[i];
|
||||
}
|
||||
dst[0..WIDTH].copy_from_slice(&tail_inverses);
|
||||
}
|
||||
}
|
||||
|
||||
/// Assumes dst is initialized and of the same length as column.
|
||||
fn batch_inverse_classic<T: FieldExpOps>(column: &[T], dst: &mut [T]) {
|
||||
let n = column.len();
|
||||
debug_assert!(dst.len() >= n);
|
||||
|
||||
dst[0] = column[0];
|
||||
// First pass.
|
||||
for i in 1..n {
|
||||
dst[i] = dst[i - 1] * column[i];
|
||||
}
|
||||
|
||||
// Inverse cumulative product.
|
||||
let mut curr_inverse = dst[n - 1].inverse();
|
||||
|
||||
// Second pass.
|
||||
for i in (1..n).rev() {
|
||||
dst[i] = dst[i - 1] * curr_inverse;
|
||||
curr_inverse *= column[i];
|
||||
}
|
||||
dst[0] = curr_inverse;
|
||||
}
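A minimal usage sketch for the batch inversion above, assuming the `stwo_prover` doctest paths used elsewhere in this crate:

use num_traits::{One, Zero};
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::FieldExpOps;

fn main() {
    let column: Vec<BaseField> = (1..=8u32).map(BaseField::from).collect();
    let mut dst = vec![BaseField::zero(); column.len()];
    // One field inversion plus O(n) multiplications instead of n inversions.
    BaseField::batch_inverse(&column, &mut dst);
    for (x, inv) in column.iter().zip(&dst) {
        assert_eq!(*x * *inv, BaseField::one());
    }
}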
|
||||
|
||||
pub trait Field:
|
||||
NumAssign
|
||||
+ Neg<Output = Self>
|
||||
+ ComplexConjugate
|
||||
+ Copy
|
||||
+ Default
|
||||
+ Debug
|
||||
+ Display
|
||||
+ PartialOrd
|
||||
+ Ord
|
||||
+ Send
|
||||
+ Sync
|
||||
+ Sized
|
||||
+ FieldExpOps
|
||||
+ Product
|
||||
+ for<'a> Product<&'a Self>
|
||||
+ Sum
|
||||
+ for<'a> Sum<&'a Self>
|
||||
{
|
||||
fn double(&self) -> Self {
|
||||
(*self) + (*self)
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// Do not use unless you are aware of the endianness of the platform you are compiling for, and of the
|
||||
/// Field element's representation in memory.
|
||||
// TODO(Ohad): Do not compile on non-le targets.
|
||||
pub unsafe trait IntoSlice<T: Sized>: Sized {
|
||||
fn into_slice(sl: &[Self]) -> &[T] {
|
||||
unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
sl.as_ptr() as *const T,
|
||||
std::mem::size_of_val(sl) / std::mem::size_of::<T>(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<F: Field> IntoSlice<u8> for F {}
|
||||
|
||||
pub trait ComplexConjugate {
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use stwo_prover::core::fields::m31::P;
|
||||
/// use stwo_prover::core::fields::qm31::QM31;
|
||||
/// use stwo_prover::core::fields::ComplexConjugate;
|
||||
///
|
||||
/// let x = QM31::from_u32_unchecked(1, 2, 3, 4);
|
||||
/// assert_eq!(
|
||||
/// x.complex_conjugate(),
|
||||
/// QM31::from_u32_unchecked(1, 2, P - 3, P - 4)
|
||||
/// );
|
||||
/// ```
|
||||
fn complex_conjugate(&self) -> Self;
|
||||
}
|
||||
|
||||
pub trait ExtensionOf<F: Field>: Field + From<F> + NumOps<F> + NumAssignOps<F> {
|
||||
const EXTENSION_DEGREE: usize;
|
||||
}
|
||||
|
||||
impl<F: Field> ExtensionOf<F> for F {
|
||||
const EXTENSION_DEGREE: usize = 1;
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! impl_field {
|
||||
($field_name: ty, $field_size: ident) => {
|
||||
use std::iter::{Product, Sum};
|
||||
|
||||
use num_traits::{Num, One, Zero};
|
||||
use $crate::core::fields::Field;
|
||||
|
||||
impl Num for $field_name {
|
||||
type FromStrRadixErr = Box<dyn std::error::Error>;
|
||||
|
||||
fn from_str_radix(_str: &str, _radix: u32) -> Result<Self, Self::FromStrRadixErr> {
|
||||
unimplemented!(
|
||||
"Num::from_str_radix is not implemented for {}",
|
||||
stringify!($field_name)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl Field for $field_name {}
|
||||
|
||||
impl AddAssign for $field_name {
|
||||
fn add_assign(&mut self, rhs: Self) {
|
||||
*self = *self + rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl SubAssign for $field_name {
|
||||
fn sub_assign(&mut self, rhs: Self) {
|
||||
*self = *self - rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl MulAssign for $field_name {
|
||||
fn mul_assign(&mut self, rhs: Self) {
|
||||
*self = *self * rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl Div for $field_name {
|
||||
type Output = Self;
|
||||
|
||||
#[allow(clippy::suspicious_arithmetic_impl)]
|
||||
fn div(self, rhs: Self) -> Self::Output {
|
||||
self * rhs.inverse()
|
||||
}
|
||||
}
|
||||
|
||||
impl DivAssign for $field_name {
|
||||
fn div_assign(&mut self, rhs: Self) {
|
||||
*self = *self / rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl Rem for $field_name {
|
||||
type Output = Self;
|
||||
|
||||
fn rem(self, _rhs: Self) -> Self::Output {
|
||||
unimplemented!("Rem is not implemented for {}", stringify!($field_name));
|
||||
}
|
||||
}
|
||||
|
||||
impl RemAssign for $field_name {
|
||||
fn rem_assign(&mut self, _rhs: Self) {
|
||||
unimplemented!(
|
||||
"RemAssign is not implemented for {}",
|
||||
stringify!($field_name)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl Product for $field_name {
|
||||
fn product<I>(mut iter: I) -> Self
|
||||
where
|
||||
I: Iterator<Item = Self>,
|
||||
{
|
||||
let first = iter.next().unwrap_or_else(Self::one);
|
||||
iter.fold(first, |a, b| a * b)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Product<&'a Self> for $field_name {
|
||||
fn product<I>(iter: I) -> Self
|
||||
where
|
||||
I: Iterator<Item = &'a Self>,
|
||||
{
|
||||
iter.map(|&v| v).product()
|
||||
}
|
||||
}
|
||||
|
||||
impl Sum for $field_name {
|
||||
fn sum<I>(mut iter: I) -> Self
|
||||
where
|
||||
I: Iterator<Item = Self>,
|
||||
{
|
||||
let first = iter.next().unwrap_or_else(Self::zero);
|
||||
iter.fold(first, |a, b| a + b)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Sum<&'a Self> for $field_name {
|
||||
fn sum<I>(iter: I) -> Self
|
||||
where
|
||||
I: Iterator<Item = &'a Self>,
|
||||
{
|
||||
iter.map(|&v| v).sum()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Used to extend a field (with characteristic M31) by 2.
|
||||
#[macro_export]
|
||||
macro_rules! impl_extension_field {
|
||||
($field_name: ident, $extended_field_name: ty) => {
|
||||
use rand::distributions::{Distribution, Standard};
|
||||
use $crate::core::fields::ExtensionOf;
|
||||
|
||||
impl ExtensionOf<M31> for $field_name {
|
||||
const EXTENSION_DEGREE: usize =
|
||||
<$extended_field_name as ExtensionOf<M31>>::EXTENSION_DEGREE * 2;
|
||||
}
|
||||
|
||||
impl Add for $field_name {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
Self(self.0 + rhs.0, self.1 + rhs.1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Neg for $field_name {
|
||||
type Output = Self;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
Self(-self.0, -self.1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub for $field_name {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
Self(self.0 - rhs.0, self.1 - rhs.1)
|
||||
}
|
||||
}
|
||||
|
||||
impl One for $field_name {
|
||||
fn one() -> Self {
|
||||
Self(
|
||||
<$extended_field_name>::one(),
|
||||
<$extended_field_name>::zero(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Zero for $field_name {
|
||||
fn zero() -> Self {
|
||||
Self(
|
||||
<$extended_field_name>::zero(),
|
||||
<$extended_field_name>::zero(),
|
||||
)
|
||||
}
|
||||
|
||||
fn is_zero(&self) -> bool {
|
||||
*self == Self::zero()
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<M31> for $field_name {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: M31) -> Self::Output {
|
||||
Self(self.0 + rhs, self.1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<$field_name> for M31 {
|
||||
type Output = $field_name;
|
||||
|
||||
fn add(self, rhs: $field_name) -> Self::Output {
|
||||
rhs + self
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<M31> for $field_name {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: M31) -> Self::Output {
|
||||
Self(self.0 - rhs, self.1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<$field_name> for M31 {
|
||||
type Output = $field_name;
|
||||
|
||||
fn sub(self, rhs: $field_name) -> Self::Output {
|
||||
-rhs + self
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<M31> for $field_name {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, rhs: M31) -> Self::Output {
|
||||
Self(self.0 * rhs, self.1 * rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul<$field_name> for M31 {
|
||||
type Output = $field_name;
|
||||
|
||||
fn mul(self, rhs: $field_name) -> Self::Output {
|
||||
rhs * self
|
||||
}
|
||||
}
|
||||
|
||||
impl Div<M31> for $field_name {
|
||||
type Output = Self;
|
||||
|
||||
fn div(self, rhs: M31) -> Self::Output {
|
||||
Self(self.0 / rhs, self.1 / rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Div<$field_name> for M31 {
|
||||
type Output = $field_name;
|
||||
|
||||
#[allow(clippy::suspicious_arithmetic_impl)]
|
||||
fn div(self, rhs: $field_name) -> Self::Output {
|
||||
rhs.inverse() * self
|
||||
}
|
||||
}
|
||||
|
||||
impl ComplexConjugate for $field_name {
|
||||
fn complex_conjugate(&self) -> Self {
|
||||
Self(self.0, -self.1)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<M31> for $field_name {
|
||||
fn from(x: M31) -> Self {
|
||||
Self(x.into(), <$extended_field_name>::zero())
|
||||
}
|
||||
}
|
||||
|
||||
impl AddAssign<M31> for $field_name {
|
||||
fn add_assign(&mut self, rhs: M31) {
|
||||
*self = *self + rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl SubAssign<M31> for $field_name {
|
||||
fn sub_assign(&mut self, rhs: M31) {
|
||||
*self = *self - rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl MulAssign<M31> for $field_name {
|
||||
fn mul_assign(&mut self, rhs: M31) {
|
||||
*self = *self * rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl DivAssign<M31> for $field_name {
|
||||
fn div_assign(&mut self, rhs: M31) {
|
||||
*self = *self / rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl Rem<M31> for $field_name {
|
||||
type Output = Self;
|
||||
|
||||
fn rem(self, _rhs: M31) -> Self::Output {
|
||||
unimplemented!("Rem is not implemented for {}", stringify!($field_name));
|
||||
}
|
||||
}
|
||||
|
||||
impl RemAssign<M31> for $field_name {
|
||||
fn rem_assign(&mut self, _rhs: M31) {
|
||||
unimplemented!(
|
||||
"RemAssign is not implemented for {}",
|
||||
stringify!($field_name)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl Distribution<$field_name> for Standard {
|
||||
// Not intended for cryptographic use. Should only be used in tests and benchmarks.
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> $field_name {
|
||||
$field_name(rng.gen(), rng.gen())
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use num_traits::Zero;
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use crate::core::fields::m31::M31;
|
||||
use crate::core::fields::FieldExpOps;
|
||||
|
||||
#[test]
|
||||
fn test_slice_batch_inverse() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let elements: [M31; 16] = rng.gen();
|
||||
let expected = elements.iter().map(|e| e.inverse()).collect::<Vec<_>>();
|
||||
let mut dst = [M31::zero(); 16];
|
||||
|
||||
M31::batch_inverse(&elements, &mut dst);
|
||||
|
||||
assert_eq!(expected, dst);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_slice_batch_inverse_wrong_dst_size() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let elements: [M31; 16] = rng.gen();
|
||||
let mut dst = [M31::zero(); 15];
|
||||
|
||||
M31::batch_inverse(&elements, &mut dst);
|
||||
}
|
||||
}
|
||||
195
Stwo_wrapper/crates/prover/src/core/fields/qm31.rs
Normal file
@ -0,0 +1,195 @@
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::ops::{
|
||||
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::secure_column::SECURE_EXTENSION_DEGREE;
|
||||
use super::{ComplexConjugate, FieldExpOps};
|
||||
use crate::core::fields::cm31::CM31;
|
||||
use crate::core::fields::m31::M31;
|
||||
use crate::{impl_extension_field, impl_field};
|
||||
|
||||
pub const P4: u128 = 21267647892944572736998860269687930881; // (2 ** 31 - 1) ** 4
|
||||
pub const R: CM31 = CM31::from_u32_unchecked(2, 1);
|
||||
|
||||
/// Extension field of CM31.
|
||||
/// Equivalent to CM31\[x\] modulo the irreducible polynomial (x^2 - 2 - i).
|
||||
/// Represented as ((a, b), (c, d)) of (a + bi) + (c + di)u.
|
||||
#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
|
||||
pub struct QM31(pub CM31, pub CM31);
|
||||
pub type SecureField = QM31;
|
||||
|
||||
impl_field!(QM31, P4);
|
||||
impl_extension_field!(QM31, CM31);
|
||||
|
||||
impl QM31 {
|
||||
pub const fn from_u32_unchecked(a: u32, b: u32, c: u32, d: u32) -> Self {
|
||||
Self(
|
||||
CM31::from_u32_unchecked(a, b),
|
||||
CM31::from_u32_unchecked(c, d),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn from_m31(a: M31, b: M31, c: M31, d: M31) -> Self {
|
||||
Self(CM31::from_m31(a, b), CM31::from_m31(c, d))
|
||||
}
|
||||
|
||||
pub fn from_m31_array(array: [M31; SECURE_EXTENSION_DEGREE]) -> Self {
|
||||
Self::from_m31(array[0], array[1], array[2], array[3])
|
||||
}
|
||||
|
||||
pub fn to_m31_array(self) -> [M31; SECURE_EXTENSION_DEGREE] {
|
||||
[self.0 .0, self.0 .1, self.1 .0, self.1 .1]
|
||||
}
|
||||
|
||||
/// Returns the combined value, given the values of its composing base field polynomials at that
|
||||
/// point.
|
||||
pub fn from_partial_evals(evals: [Self; SECURE_EXTENSION_DEGREE]) -> Self {
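// Recombine the four coordinate polynomials against the basis (1, i, u, i*u) of QM31 over M31.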
|
||||
let mut res = evals[0];
|
||||
res += evals[1] * Self::from_u32_unchecked(0, 1, 0, 0);
|
||||
res += evals[2] * Self::from_u32_unchecked(0, 0, 1, 0);
|
||||
res += evals[3] * Self::from_u32_unchecked(0, 0, 0, 1);
|
||||
res
|
||||
}
|
||||
|
||||
// Note: Adding this as a `Mul` impl confuses Rust's type inference, which tries to resolve QM31*QM31 as
|
||||
// QM31*CM31.
|
||||
pub fn mul_cm31(self, rhs: CM31) -> Self {
|
||||
Self(self.0 * rhs, self.1 * rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for QM31 {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "({}) + ({})u", self.0, self.1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for QM31 {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "({}) + ({})u", self.0, self.1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul for QM31 {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
// (a + bu) * (c + du) = (ac + rbd) + (ad + bc)u.
|
||||
Self(
|
||||
self.0 * rhs.0 + R * self.1 * rhs.1,
|
||||
self.0 * rhs.1 + self.1 * rhs.0,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<usize> for QM31 {
|
||||
fn from(value: usize) -> Self {
|
||||
M31::from(value).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u32> for QM31 {
|
||||
fn from(value: u32) -> Self {
|
||||
M31::from(value).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<i32> for QM31 {
|
||||
fn from(value: i32) -> Self {
|
||||
M31::from(value).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryInto<M31> for QM31 {
|
||||
type Error = ();
|
||||
|
||||
fn try_into(self) -> Result<M31, Self::Error> {
|
||||
if self.1 != CM31::zero() {
|
||||
return Err(());
|
||||
}
|
||||
self.0.try_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl FieldExpOps for QM31 {
|
||||
fn inverse(&self) -> Self {
|
||||
assert!(!self.is_zero(), "0 has no inverse");
|
||||
// (a + bu)^-1 = (a - bu) / (a^2 - (2+i)b^2).
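// This holds because u^2 = 2 + i (x^2 - 2 - i is the defining polynomial), so the norm
// (a + bu)(a - bu) = a^2 - (2 + i)b^2 lies in CM31 and can be inverted there.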
|
||||
let b2 = self.1.square();
|
||||
let ib2 = CM31(-b2.1, b2.0);
|
||||
let denom = self.0.square() - (b2 + b2 + ib2);
|
||||
let denom_inverse = denom.inverse();
|
||||
Self(self.0 * denom_inverse, -self.1 * denom_inverse)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[macro_export]
|
||||
macro_rules! qm31 {
|
||||
($m0:expr, $m1:expr, $m2:expr, $m3:expr) => {{
|
||||
use $crate::core::fields::qm31::QM31;
|
||||
QM31::from_u32_unchecked($m0, $m1, $m2, $m3)
|
||||
}};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use num_traits::One;
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use super::QM31;
|
||||
use crate::core::fields::m31::P;
|
||||
use crate::core::fields::{FieldExpOps, IntoSlice};
|
||||
use crate::m31;
|
||||
|
||||
#[test]
|
||||
fn test_inverse() {
|
||||
let qm = qm31!(1, 2, 3, 4);
|
||||
let qm_inv = qm.inverse();
|
||||
assert_eq!(qm * qm_inv, QM31::one());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ops() {
|
||||
let qm0 = qm31!(1, 2, 3, 4);
|
||||
let qm1 = qm31!(4, 5, 6, 7);
|
||||
let m = m31!(8);
|
||||
let qm = QM31::from(m);
|
||||
let qm0_x_qm1 = qm31!(P - 71, 93, P - 16, 50);
|
||||
|
||||
assert_eq!(qm0 + qm1, qm31!(5, 7, 9, 11));
|
||||
assert_eq!(qm1 + m, qm1 + qm);
|
||||
assert_eq!(qm0 * qm1, qm0_x_qm1);
|
||||
assert_eq!(qm1 * m, qm1 * qm);
|
||||
assert_eq!(-qm0, qm31!(P - 1, P - 2, P - 3, P - 4));
|
||||
assert_eq!(qm0 - qm1, qm31!(P - 3, P - 3, P - 3, P - 3));
|
||||
assert_eq!(qm1 - m, qm1 - qm);
|
||||
assert_eq!(qm0_x_qm1 / qm1, qm31!(1, 2, 3, 4));
|
||||
assert_eq!(qm1 / m, qm1 / qm);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_into_slice() {
|
||||
let mut rng = SmallRng::seed_from_u64(0);
|
||||
let x = (0..100).map(|_| rng.gen()).collect::<Vec<QM31>>();
|
||||
|
||||
let slice = QM31::into_slice(&x);
|
||||
|
||||
for i in 0..100 {
|
||||
let corresponding_sub_slice = &slice[i * 16..(i + 1) * 16];
|
||||
assert_eq!(
|
||||
x[i],
|
||||
qm31!(
|
||||
u32::from_le_bytes(corresponding_sub_slice[..4].try_into().unwrap()),
|
||||
u32::from_le_bytes(corresponding_sub_slice[4..8].try_into().unwrap()),
|
||||
u32::from_le_bytes(corresponding_sub_slice[8..12].try_into().unwrap()),
|
||||
u32::from_le_bytes(corresponding_sub_slice[12..16].try_into().unwrap())
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
111
Stwo_wrapper/crates/prover/src/core/fields/secure_column.rs
Normal file
@ -0,0 +1,111 @@
|
||||
use std::array;
|
||||
use std::iter::zip;
|
||||
|
||||
use super::m31::BaseField;
|
||||
use super::qm31::SecureField;
|
||||
use super::{ExtensionOf, FieldOps};
|
||||
use crate::core::backend::{Col, Column, CpuBackend};
|
||||
|
||||
pub const SECURE_EXTENSION_DEGREE: usize =
|
||||
<SecureField as ExtensionOf<BaseField>>::EXTENSION_DEGREE;
|
||||
|
||||
/// A column-major array of `SECURE_EXTENSION_DEGREE` base field columns that represents a column
|
||||
/// of secure field element coordinates.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SecureColumnByCoords<B: FieldOps<BaseField>> {
|
||||
pub columns: [Col<B, BaseField>; SECURE_EXTENSION_DEGREE],
|
||||
}
|
||||
impl SecureColumnByCoords<CpuBackend> {
|
||||
// TODO(spapini): Remove when we no longer use CircleEvaluation<SecureField>.
|
||||
pub fn to_vec(&self) -> Vec<SecureField> {
|
||||
(0..self.len()).map(|i| self.at(i)).collect()
|
||||
}
|
||||
}
|
||||
impl<B: FieldOps<BaseField>> SecureColumnByCoords<B> {
|
||||
pub fn at(&self, index: usize) -> SecureField {
|
||||
SecureField::from_m31_array(std::array::from_fn(|i| self.columns[i].at(index)))
|
||||
}
|
||||
|
||||
pub fn zeros(len: usize) -> Self {
|
||||
Self {
|
||||
columns: std::array::from_fn(|_| Col::<B, BaseField>::zeros(len)),
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
///
/// The returned column's contents are uninitialized; callers must write every entry before reading it.
|
||||
pub unsafe fn uninitialized(len: usize) -> Self {
|
||||
Self {
|
||||
columns: std::array::from_fn(|_| Col::<B, BaseField>::uninitialized(len)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.columns[0].len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.columns[0].is_empty()
|
||||
}
|
||||
|
||||
pub fn to_cpu(&self) -> SecureColumnByCoords<CpuBackend> {
|
||||
SecureColumnByCoords {
|
||||
columns: self.columns.clone().map(|c| c.to_cpu()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set(&mut self, index: usize, value: SecureField) {
|
||||
let values = value.to_m31_array();
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for i in 0..SECURE_EXTENSION_DEGREE {
|
||||
self.columns[i].set(index, values[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SecureColumnByCoordsIter<'a> {
|
||||
column: &'a SecureColumnByCoords<CpuBackend>,
|
||||
index: usize,
|
||||
}
|
||||
impl Iterator for SecureColumnByCoordsIter<'_> {
|
||||
type Item = SecureField;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.index < self.column.len() {
|
||||
let value = self.column.at(self.index);
|
||||
self.index += 1;
|
||||
Some(value)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<'a> IntoIterator for &'a SecureColumnByCoords<CpuBackend> {
|
||||
type Item = SecureField;
|
||||
type IntoIter = SecureColumnByCoordsIter<'a>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
SecureColumnByCoordsIter {
|
||||
column: self,
|
||||
index: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl FromIterator<SecureField> for SecureColumnByCoords<CpuBackend> {
|
||||
fn from_iter<I: IntoIterator<Item = SecureField>>(iter: I) -> Self {
|
||||
let values = iter.into_iter();
|
||||
let (lower_bound, _) = values.size_hint();
|
||||
let mut columns = array::from_fn(|_| Vec::with_capacity(lower_bound));
|
||||
|
||||
for value in values {
|
||||
let coords = value.to_m31_array();
|
||||
zip(&mut columns, coords).for_each(|(col, coord)| col.push(coord));
|
||||
}
|
||||
|
||||
SecureColumnByCoords { columns }
|
||||
}
|
||||
}
|
||||
impl From<SecureColumnByCoords<CpuBackend>> for Vec<SecureField> {
|
||||
fn from(column: SecureColumnByCoords<CpuBackend>) -> Self {
|
||||
column.into_iter().collect()
|
||||
}
|
||||
}
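A minimal sketch of building and reading back a `SecureColumnByCoords`, assuming the `stwo_prover` paths used by the doctests in this crate:

use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::secure_column::SecureColumnByCoords;

fn main() {
    // Four secure field values, stored as four base-field coordinate columns.
    let values: Vec<SecureField> = (0..4u32).map(SecureField::from).collect();
    let column: SecureColumnByCoords<CpuBackend> = values.iter().copied().collect();
    assert_eq!(column.len(), 4);
    assert_eq!(column.to_vec(), values);
}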
|
||||
1424
Stwo_wrapper/crates/prover/src/core/fri.rs
Normal file
File diff suppressed because it is too large
566
Stwo_wrapper/crates/prover/src/core/lookups/gkr_prover.rs
Normal file
@ -0,0 +1,566 @@
|
||||
//! GKR batch prover for Grand Product and LogUp lookup arguments.
|
||||
use std::borrow::Cow;
|
||||
use std::iter::{successors, zip};
|
||||
use std::ops::Deref;
|
||||
|
||||
use educe::Educe;
|
||||
use itertools::Itertools;
|
||||
use num_traits::{One, Zero};
|
||||
use thiserror::Error;
|
||||
|
||||
use super::gkr_verifier::{GkrArtifact, GkrBatchProof, GkrMask};
|
||||
use super::mle::{Mle, MleOps};
|
||||
use super::sumcheck::MultivariatePolyOracle;
|
||||
use super::utils::{eq, random_linear_combination, UnivariatePoly};
|
||||
use crate::core::backend::{Col, Column, ColumnOps, CpuBackend};
|
||||
use crate::core::channel::Channel;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::{Field, FieldExpOps};
|
||||
use crate::core::lookups::sumcheck;
|
||||
|
||||
pub trait GkrOps: MleOps<BaseField> + MleOps<SecureField> {
|
||||
/// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`.
|
||||
///
|
||||
/// Note [`Mle`] stores values in bit-reversed order.
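///
/// Here `eq(x, y) = prod_i (x_i * y_i + (1 - x_i) * (1 - y_i))`: the multilinear polynomial that
/// equals 1 when `x == y` on the boolean hypercube and 0 on every other hypercube point.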
|
||||
///
|
||||
/// [`eq(x, y)`]: crate::core::lookups::utils::eq
|
||||
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField>;
|
||||
|
||||
/// Generates the next GKR layer from the current one.
|
||||
fn next_layer(layer: &Layer<Self>) -> Layer<Self>;
|
||||
|
||||
/// Returns univariate polynomial `f(t) = sum_x h(t, x)` for all `x` in the boolean hypercube.
|
||||
///
|
||||
/// `claim` equals `f(0) + f(1)`.
|
||||
///
|
||||
/// For more context see docs of [`MultivariatePolyOracle::sum_as_poly_in_first_variable()`].
|
||||
fn sum_as_poly_in_first_variable(
|
||||
h: &GkrMultivariatePolyOracle<'_, Self>,
|
||||
claim: SecureField,
|
||||
) -> UnivariatePoly<SecureField>;
|
||||
}
|
||||
|
||||
/// Stores evaluations of [`eq(x, y)`] on all boolean hypercube points of the form
|
||||
/// `x = (0, x_1, ..., x_{n-1})`.
|
||||
///
|
||||
/// Evaluations are stored in bit-reversed order i.e. `evals[0] = eq((0, ..., 0, 0), y)`,
|
||||
/// `evals[1] = eq((0, ..., 0, 1), y)`, etc.
|
||||
///
|
||||
/// [`eq(x, y)`]: crate::core::lookups::utils::eq
|
||||
#[derive(Educe)]
|
||||
#[educe(Debug, Clone)]
|
||||
pub struct EqEvals<B: ColumnOps<SecureField>> {
|
||||
y: Vec<SecureField>,
|
||||
evals: Mle<B, SecureField>,
|
||||
}
|
||||
|
||||
impl<B: GkrOps> EqEvals<B> {
|
||||
pub fn generate(y: &[SecureField]) -> Self {
|
||||
let y = y.to_vec();
|
||||
|
||||
if y.is_empty() {
|
||||
let evals = Mle::new([SecureField::one()].into_iter().collect());
|
||||
return Self { evals, y };
|
||||
}
|
||||
|
||||
let evals = B::gen_eq_evals(&y[1..], eq(&[SecureField::zero()], &[y[0]]));
|
||||
assert_eq!(evals.len(), 1 << (y.len() - 1));
|
||||
Self { evals, y }
|
||||
}
|
||||
|
||||
/// Returns the fixed vector `y` used to generate the evaluations.
|
||||
pub fn y(&self) -> &[SecureField] {
|
||||
&self.y
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: ColumnOps<SecureField>> Deref for EqEvals<B> {
|
||||
type Target = Col<B, SecureField>;
|
||||
|
||||
fn deref(&self) -> &Col<B, SecureField> {
|
||||
&self.evals
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a layer in a binary tree structured GKR circuit.
|
||||
///
|
||||
/// Layers can contain multiple columns, for example [LogUp], which has separate columns for
|
||||
/// numerators and denominators.
|
||||
///
|
||||
/// [LogUp]: https://eprint.iacr.org/2023/1284.pdf
|
||||
#[derive(Educe)]
|
||||
#[educe(Debug, Clone)]
|
||||
pub enum Layer<B: GkrOps> {
|
||||
GrandProduct(Mle<B, SecureField>),
|
||||
LogUpGeneric {
|
||||
numerators: Mle<B, SecureField>,
|
||||
denominators: Mle<B, SecureField>,
|
||||
},
|
||||
LogUpMultiplicities {
|
||||
numerators: Mle<B, BaseField>,
|
||||
denominators: Mle<B, SecureField>,
|
||||
},
|
||||
/// All numerators implicitly equal "1".
|
||||
LogUpSingles {
|
||||
denominators: Mle<B, SecureField>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<B: GkrOps> Layer<B> {
|
||||
/// Returns the number of variables used to interpolate the layer's gate values.
|
||||
pub fn n_variables(&self) -> usize {
|
||||
match self {
|
||||
Self::GrandProduct(mle)
|
||||
| Self::LogUpSingles { denominators: mle }
|
||||
| Self::LogUpMultiplicities {
|
||||
denominators: mle, ..
|
||||
}
|
||||
| Self::LogUpGeneric {
|
||||
denominators: mle, ..
|
||||
} => mle.n_variables(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_output_layer(&self) -> bool {
|
||||
self.n_variables() == 0
|
||||
}
|
||||
|
||||
/// Produces the next layer from the current layer.
|
||||
///
|
||||
/// The next layer is strictly half the size of the current layer.
|
||||
/// Returns [`None`] if called on an output layer.
|
||||
pub fn next_layer(&self) -> Option<Self> {
|
||||
if self.is_output_layer() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(B::next_layer(self))
|
||||
}
|
||||
|
||||
/// Returns each column output if the layer is an output layer, otherwise returns an `Err`.
|
||||
fn try_into_output_layer_values(self) -> Result<Vec<SecureField>, NotOutputLayerError> {
|
||||
if !self.is_output_layer() {
|
||||
return Err(NotOutputLayerError);
|
||||
}
|
||||
|
||||
Ok(match self {
|
||||
Layer::LogUpSingles { denominators } => {
|
||||
let numerator = SecureField::one();
|
||||
let denominator = denominators.at(0);
|
||||
vec![numerator, denominator]
|
||||
}
|
||||
Layer::LogUpMultiplicities {
|
||||
numerators,
|
||||
denominators,
|
||||
} => {
|
||||
let numerator = numerators.at(0).into();
|
||||
let denominator = denominators.at(0);
|
||||
vec![numerator, denominator]
|
||||
}
|
||||
Layer::LogUpGeneric {
|
||||
numerators,
|
||||
denominators,
|
||||
} => {
|
||||
let numerator = numerators.at(0);
|
||||
let denominator = denominators.at(0);
|
||||
vec![numerator, denominator]
|
||||
}
|
||||
Layer::GrandProduct(col) => {
|
||||
vec![col.at(0)]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a transformed layer with the first variable of each column fixed to `assignment`.
|
||||
fn fix_first_variable(self, x0: SecureField) -> Self {
|
||||
if self.n_variables() == 0 {
|
||||
return self;
|
||||
}
|
||||
|
||||
match self {
|
||||
Self::GrandProduct(mle) => Self::GrandProduct(mle.fix_first_variable(x0)),
|
||||
Self::LogUpGeneric {
|
||||
numerators,
|
||||
denominators,
|
||||
} => Self::LogUpGeneric {
|
||||
numerators: numerators.fix_first_variable(x0),
|
||||
denominators: denominators.fix_first_variable(x0),
|
||||
},
|
||||
Self::LogUpMultiplicities {
|
||||
numerators,
|
||||
denominators,
|
||||
} => Self::LogUpGeneric {
|
||||
numerators: numerators.fix_first_variable(x0),
|
||||
denominators: denominators.fix_first_variable(x0),
|
||||
},
|
||||
Self::LogUpSingles { denominators } => Self::LogUpSingles {
|
||||
denominators: denominators.fix_first_variable(x0),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the next GKR layer evaluation as a multivariate polynomial which uses this GKR
|
||||
/// layer as input.
|
||||
///
|
||||
/// Layers can contain multiple columns `c_0, ..., c_{n-1}` with multivariate polynomial `g_i`
|
||||
/// representing[^note] `c_i` in the next layer. These polynomials must be combined with
|
||||
/// `lambda` into a single polynomial `h = g_0 + lambda * g_1 + ... + lambda^(n-1) *
|
||||
/// g_{n-1}`. The oracle for `h` should be returned.
|
||||
///
|
||||
/// # Optimization: precomputed [`eq(x, y)`] evals
|
||||
///
|
||||
/// Let `y` be a fixed vector of length `m` and let `z` be a subvector comprising the
|
||||
/// last `k` elements of `y`. `h(x)` **must** equal some multivariate polynomial of the form
|
||||
/// `eq(x, z) * p(x)`. A common operation will be computing the univariate polynomial `f(t) =
|
||||
/// sum_x h(t, x)` for `x` in the boolean hypercube `{0, 1}^(k-1)`.
|
||||
///
|
||||
/// `eq_evals` stores evaluations of `eq((0, x), y)` for `x` in a potentially extended boolean
|
||||
/// hypercube `{0, 1}^{m-1}`. These evaluations, on the extended hypercube, can be used directly
|
||||
/// in computing the sums of `h(x)`, however a correction factor must be applied to the final
|
||||
/// sum which is handled by [`correct_sum_as_poly_in_first_variable()`] in `O(m)`.
|
||||
///
|
||||
/// Being able to compute sums of `h(x)` using `eq_evals` in this way leads to a more efficient
|
||||
/// implementation because the prover only has to generate `eq_evals` once for an entire batch
|
||||
/// of multiple GKR layer instances.
|
||||
///
|
||||
/// [`eq(x, y)`]: crate::core::lookups::utils::eq
|
||||
/// [^note]: By "representing" we mean `g_i` agrees with the next layer's `c_i` on the boolean
|
||||
/// hypercube that interpolates `c_i`.
|
||||
fn into_multivariate_poly(
|
||||
self,
|
||||
lambda: SecureField,
|
||||
eq_evals: &EqEvals<B>,
|
||||
) -> GkrMultivariatePolyOracle<'_, B> {
|
||||
GkrMultivariatePolyOracle {
|
||||
eq_evals: Cow::Borrowed(eq_evals),
|
||||
input_layer: self,
|
||||
eq_fixed_var_correction: SecureField::one(),
|
||||
lambda,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a copy of this layer with the [`CpuBackend`].
|
||||
///
|
||||
/// This operation is expensive but can be useful for small traces that are difficult to handle
|
||||
/// depending on the backend. For example, the SIMD backend offloads to the CPU backend when
|
||||
/// trace length becomes smaller than the SIMD lane count.
|
||||
pub fn to_cpu(&self) -> Layer<CpuBackend> {
|
||||
match self {
|
||||
Layer::GrandProduct(mle) => Layer::GrandProduct(Mle::new(mle.to_cpu())),
|
||||
Layer::LogUpGeneric {
|
||||
numerators,
|
||||
denominators,
|
||||
} => Layer::LogUpGeneric {
|
||||
numerators: Mle::new(numerators.to_cpu()),
|
||||
denominators: Mle::new(denominators.to_cpu()),
|
||||
},
|
||||
Layer::LogUpMultiplicities {
|
||||
numerators,
|
||||
denominators,
|
||||
} => Layer::LogUpMultiplicities {
|
||||
numerators: Mle::new(numerators.to_cpu()),
|
||||
denominators: Mle::new(denominators.to_cpu()),
|
||||
},
|
||||
Layer::LogUpSingles { denominators } => Layer::LogUpSingles {
|
||||
denominators: Mle::new(denominators.to_cpu()),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NotOutputLayerError;
|
||||
|
||||
/// Multivariate polynomial `P` that expresses the relation between two consecutive GKR layers.
|
||||
///
|
||||
/// When the input layer is [`Layer::GrandProduct`] (represented by multilinear column `inp`)
|
||||
/// the polynomial represents:
|
||||
///
|
||||
/// ```text
|
||||
/// P(x) = eq(x, y) * inp(x, 0) * inp(x, 1)
|
||||
/// ```
|
||||
///
|
||||
/// When the input layer is LogUp (represented by multilinear columns `inp_numer` and
|
||||
/// `inp_denom`) the polynomial represents:
|
||||
///
|
||||
/// ```text
|
||||
/// numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)
|
||||
/// denom(x) = inp_denom(x, 0) * inp_denom(x, 1)
|
||||
///
|
||||
/// P(x) = eq(x, y) * (numer(x) + lambda * denom(x))
|
||||
/// ```
|
||||
pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> {
|
||||
/// `eq_evals` passed by `Layer::into_multivariate_poly()`.
|
||||
pub eq_evals: Cow<'a, EqEvals<B>>,
|
||||
pub input_layer: Layer<B>,
|
||||
pub eq_fixed_var_correction: SecureField,
|
||||
/// Used by LogUp to perform a random linear combination of the numerators and denominators.
|
||||
pub lambda: SecureField,
|
||||
}
|
||||
|
||||
impl<'a, B: GkrOps> MultivariatePolyOracle for GkrMultivariatePolyOracle<'a, B> {
|
||||
fn n_variables(&self) -> usize {
|
||||
self.input_layer.n_variables() - 1
|
||||
}
|
||||
|
||||
fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly<SecureField> {
|
||||
B::sum_as_poly_in_first_variable(self, claim)
|
||||
}
|
||||
|
||||
fn fix_first_variable(self, challenge: SecureField) -> Self {
|
||||
if self.is_constant() {
|
||||
return self;
|
||||
}
|
||||
|
||||
let z0 = self.eq_evals.y()[self.eq_evals.y().len() - self.n_variables()];
|
||||
let eq_fixed_var_correction = self.eq_fixed_var_correction * eq(&[challenge], &[z0]);
|
||||
|
||||
Self {
|
||||
eq_evals: self.eq_evals,
|
||||
eq_fixed_var_correction,
|
||||
input_layer: self.input_layer.fix_first_variable(challenge),
|
||||
lambda: self.lambda,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> {
|
||||
fn is_constant(&self) -> bool {
|
||||
self.n_variables() == 0
|
||||
}
|
||||
|
||||
/// Returns all input layer columns restricted to a line.
|
||||
///
|
||||
/// Let `l` be the line satisfying `l(0) = b*` and `l(1) = c*`. Oracles that represent constants
|
||||
/// are expressed by values `c_i(b*)` and `c_i(c*)` where `c_i` represents the input GKR layer's
|
||||
/// `i`th column (for binary tree GKR `b* = (r, 0)`, `c* = (r, 1)`).
|
||||
///
|
||||
/// If this oracle represents a constant, then each `c_i` restricted to `l` is returned.
|
||||
/// Otherwise, an [`Err`] is returned.
|
||||
///
|
||||
/// For more context see <https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf> page 64.
|
||||
fn try_into_mask(self) -> Result<GkrMask, NotConstantPolyError> {
|
||||
if !self.is_constant() {
|
||||
return Err(NotConstantPolyError);
|
||||
}
|
||||
|
||||
let columns = match self.input_layer {
|
||||
Layer::GrandProduct(mle) => vec![mle.to_cpu().try_into().unwrap()],
|
||||
Layer::LogUpGeneric {
|
||||
numerators,
|
||||
denominators,
|
||||
} => {
|
||||
let numerators = numerators.to_cpu().try_into().unwrap();
|
||||
let denominators = denominators.to_cpu().try_into().unwrap();
|
||||
vec![numerators, denominators]
|
||||
}
|
||||
// Should never get called.
|
||||
Layer::LogUpMultiplicities { .. } => unimplemented!(),
|
||||
Layer::LogUpSingles { denominators } => {
|
||||
let numerators = [SecureField::one(); 2];
|
||||
let denominators = denominators.to_cpu().try_into().unwrap();
|
||||
vec![numerators, denominators]
|
||||
}
|
||||
};
|
||||
|
||||
Ok(GkrMask::new(columns))
|
||||
}
|
||||
|
||||
/// Returns a copy of this oracle with the [`CpuBackend`].
|
||||
///
|
||||
/// This operation is expensive but can be useful for small oracles that are difficult to handle
|
||||
/// depending on the backend. For example, the SIMD backend offloads to the CPU backend when
|
||||
/// trace length becomes smaller than the SIMD lane count.
|
||||
pub fn to_cpu(&self) -> GkrMultivariatePolyOracle<'a, CpuBackend> {
|
||||
// TODO(andrew): This block is not ideal.
|
||||
let n_eq_evals = 1 << (self.n_variables() - 1);
|
||||
let eq_evals = Cow::Owned(EqEvals {
|
||||
evals: Mle::new((0..n_eq_evals).map(|i| self.eq_evals.at(i)).collect()),
|
||||
y: self.eq_evals.y.to_vec(),
|
||||
});
|
||||
|
||||
GkrMultivariatePolyOracle {
|
||||
eq_evals,
|
||||
eq_fixed_var_correction: self.eq_fixed_var_correction,
|
||||
input_layer: self.input_layer.to_cpu(),
|
||||
lambda: self.lambda,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Error returned when a polynomial is expected to be constant but it is not.
|
||||
#[derive(Debug, Error)]
|
||||
#[error("polynomial is not constant")]
|
||||
pub struct NotConstantPolyError;
|
||||
|
||||
/// Batch proves lookup circuits with GKR.
|
||||
///
|
||||
/// The input layers should be committed to the channel before calling this function.
|
||||
// GKR algorithm: <https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf> (page 64)
|
||||
pub fn prove_batch<B: GkrOps>(
|
||||
channel: &mut impl Channel,
|
||||
input_layer_by_instance: Vec<Layer<B>>,
|
||||
) -> (GkrBatchProof, GkrArtifact) {
|
||||
let n_instances = input_layer_by_instance.len();
|
||||
let n_layers_by_instance = input_layer_by_instance
|
||||
.iter()
|
||||
.map(|l| l.n_variables())
|
||||
.collect_vec();
|
||||
let n_layers = *n_layers_by_instance.iter().max().unwrap();
|
||||
|
||||
// Evaluate all instance circuits and collect the layer values.
|
||||
let mut layers_by_instance = input_layer_by_instance
|
||||
.into_iter()
|
||||
.map(|input_layer| gen_layers(input_layer).into_iter().rev())
|
||||
.collect_vec();
|
||||
|
||||
let mut output_claims_by_instance = vec![None; n_instances];
|
||||
let mut layer_masks_by_instance = (0..n_instances).map(|_| Vec::new()).collect_vec();
|
||||
let mut sumcheck_proofs = Vec::new();
|
||||
|
||||
let mut ood_point = Vec::new();
|
||||
let mut claims_to_verify_by_instance = vec![None; n_instances];
|
||||
|
||||
for layer in 0..n_layers {
|
||||
let n_remaining_layers = n_layers - layer;
|
||||
|
||||
// Check all the instances for output layers.
|
||||
for (instance, layers) in layers_by_instance.iter_mut().enumerate() {
|
||||
if n_layers_by_instance[instance] == n_remaining_layers {
|
||||
let output_layer = layers.next().unwrap();
|
||||
let output_layer_values = output_layer.try_into_output_layer_values().unwrap();
|
||||
claims_to_verify_by_instance[instance] = Some(output_layer_values.clone());
|
||||
output_claims_by_instance[instance] = Some(output_layer_values);
|
||||
}
|
||||
}
|
||||
|
||||
// Seed the channel with layer claims.
|
||||
for claims_to_verify in claims_to_verify_by_instance.iter().flatten() {
|
||||
channel.mix_felts(claims_to_verify);
|
||||
}
|
||||
|
||||
let eq_evals = EqEvals::generate(&ood_point);
|
||||
let sumcheck_alpha = channel.draw_felt();
|
||||
let instance_lambda = channel.draw_felt();
|
||||
|
||||
let mut sumcheck_oracles = Vec::new();
|
||||
let mut sumcheck_claims = Vec::new();
|
||||
let mut sumcheck_instances = Vec::new();
|
||||
|
||||
// Create the multivariate polynomial oracles used with sumcheck.
|
||||
for (instance, claims_to_verify) in claims_to_verify_by_instance.iter().enumerate() {
|
||||
if let Some(claims_to_verify) = claims_to_verify {
|
||||
let layer = layers_by_instance[instance].next().unwrap();
|
||||
sumcheck_oracles.push(layer.into_multivariate_poly(instance_lambda, &eq_evals));
|
||||
sumcheck_claims.push(random_linear_combination(claims_to_verify, instance_lambda));
|
||||
sumcheck_instances.push(instance);
|
||||
}
|
||||
}
|
||||
|
||||
let (sumcheck_proof, sumcheck_ood_point, constant_poly_oracles, _) =
|
||||
sumcheck::prove_batch(sumcheck_claims, sumcheck_oracles, sumcheck_alpha, channel);
|
||||
|
||||
sumcheck_proofs.push(sumcheck_proof);
|
||||
|
||||
let masks = constant_poly_oracles
|
||||
.into_iter()
|
||||
.map(|oracle| oracle.try_into_mask().unwrap())
|
||||
.collect_vec();
|
||||
|
||||
// Seed the channel with the layer masks.
|
||||
for (&instance, mask) in zip(&sumcheck_instances, &masks) {
|
||||
channel.mix_felts(mask.columns().flatten());
|
||||
layer_masks_by_instance[instance].push(mask.clone());
|
||||
}
|
||||
|
||||
let challenge = channel.draw_felt();
|
||||
ood_point = sumcheck_ood_point;
|
||||
ood_point.push(challenge);
|
||||
|
||||
// Set the claims to prove in the layer above.
|
||||
for (instance, mask) in zip(sumcheck_instances, masks) {
|
||||
claims_to_verify_by_instance[instance] = Some(mask.reduce_at_point(challenge));
|
||||
}
|
||||
}
|
||||
|
||||
let output_claims_by_instance = output_claims_by_instance
|
||||
.into_iter()
|
||||
.map(Option::unwrap)
|
||||
.collect();
|
||||
|
||||
let claims_to_verify_by_instance = claims_to_verify_by_instance
|
||||
.into_iter()
|
||||
.map(Option::unwrap)
|
||||
.collect();
|
||||
|
||||
let proof = GkrBatchProof {
|
||||
sumcheck_proofs,
|
||||
layer_masks_by_instance,
|
||||
output_claims_by_instance,
|
||||
};
|
||||
|
||||
let artifact = GkrArtifact {
|
||||
ood_point,
|
||||
claims_to_verify_by_instance,
|
||||
n_variables_by_instance: n_layers_by_instance,
|
||||
};
|
||||
|
||||
(proof, artifact)
|
||||
}
|
||||
|
||||
/// Executes the GKR circuit on the input layer and returns all the circuit's layers.
|
||||
fn gen_layers<B: GkrOps>(input_layer: Layer<B>) -> Vec<Layer<B>> {
|
||||
let n_variables = input_layer.n_variables();
|
||||
let layers = successors(Some(input_layer), |layer| layer.next_layer()).collect_vec();
|
||||
assert_eq!(layers.len(), n_variables + 1);
|
||||
layers
|
||||
}
|
||||
|
||||
/// Computes `r(t) = sum_x eq((t, x), y[-k:]) * p(t, x)` from evaluations of
|
||||
/// `f(t) = sum_x eq(({0}^(n - k), 0, x), y) * p(t, x)`.
|
||||
///
|
||||
/// Note `claim` must equal `r(0) + r(1)` and `r` must have degree <= 3.
|
||||
///
|
||||
/// For more context see `Layer::into_multivariate_poly()` docs.
|
||||
/// See also <https://ia.cr/2024/108> (section 3.2).
|
||||
pub fn correct_sum_as_poly_in_first_variable(
|
||||
f_at_0: SecureField,
|
||||
f_at_2: SecureField,
|
||||
claim: SecureField,
|
||||
y: &[SecureField],
|
||||
k: usize,
|
||||
) -> UnivariatePoly<SecureField> {
|
||||
assert_ne!(k, 0);
|
||||
let n = y.len();
|
||||
assert!(k <= n);
|
||||
|
||||
// We evaluated `f(0)` and `f(2)` - the inputs.
|
||||
// We want to compute `r(t) = f(t) * eq(t, y[n - k]) / eq(0, y[:n - k + 1])`.
|
||||
let a_const = eq(&vec![SecureField::zero(); n - k + 1], &y[..n - k + 1]).inverse();
|
||||
|
||||
// Find the additional root of `r(t)`, by finding the root of `eq(t, y[n - k])`:
|
||||
// 0 = eq(t, y[n - k])
|
||||
// = t * y[n - k] + (1 - t)(1 - y[n - k])
|
||||
// = 1 - y[n - k] - t(1 - 2 * y[n - k])
|
||||
// => t = (1 - y[n - k]) / (1 - 2 * y[n - k])
|
||||
// = b
|
||||
let b_const = (SecureField::one() - y[n - k]) / (SecureField::one() - y[n - k].double());
|
||||
|
||||
// We get that `r(t) = f(t) * eq(t, y[n - k]) * a`.
|
||||
let r_at_0 = f_at_0 * eq(&[SecureField::zero()], &[y[n - k]]) * a_const;
|
||||
let r_at_1 = claim - r_at_0;
|
||||
let r_at_2 = f_at_2 * eq(&[BaseField::from(2).into()], &[y[n - k]]) * a_const;
|
||||
let r_at_b = SecureField::zero();
|
||||
|
||||
// Interpolate.
|
||||
UnivariatePoly::interpolate_lagrange(
|
||||
&[
|
||||
SecureField::zero(),
|
||||
SecureField::one(),
|
||||
SecureField::from(BaseField::from(2)),
|
||||
b_const,
|
||||
],
|
||||
&[r_at_0, r_at_1, r_at_2, r_at_b],
|
||||
)
|
||||
}
|
||||
357
Stwo_wrapper/crates/prover/src/core/lookups/gkr_verifier.rs
Normal file
@ -0,0 +1,357 @@
|
||||
//! GKR batch verifier for Grand Product and LogUp lookup arguments.
|
||||
use thiserror::Error;
|
||||
|
||||
use super::sumcheck::{SumcheckError, SumcheckProof};
|
||||
use super::utils::{eq, fold_mle_evals, random_linear_combination};
|
||||
use crate::core::channel::Channel;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::lookups::sumcheck;
|
||||
use crate::core::lookups::utils::Fraction;
|
||||
|
||||
/// Partially verifies a batch GKR proof.
|
||||
///
|
||||
/// On successful verification the function returns a [`GkrArtifact`] which stores the out-of-domain
|
||||
/// point and claimed evaluations in the input layer columns for each instance at the OOD point.
|
||||
/// These claimed evaluations are not checked in this function - hence partial verification.
|
||||
pub fn partially_verify_batch(
|
||||
gate_by_instance: Vec<Gate>,
|
||||
proof: &GkrBatchProof,
|
||||
channel: &mut impl Channel,
|
||||
) -> Result<GkrArtifact, GkrError> {
|
||||
let GkrBatchProof {
|
||||
sumcheck_proofs,
|
||||
layer_masks_by_instance,
|
||||
output_claims_by_instance,
|
||||
} = proof;
|
||||
|
||||
if layer_masks_by_instance.len() != output_claims_by_instance.len() {
|
||||
return Err(GkrError::MalformedProof);
|
||||
}
|
||||
|
||||
let n_instances = layer_masks_by_instance.len();
|
||||
let instance_n_layers = |instance: usize| layer_masks_by_instance[instance].len();
|
||||
let n_layers = (0..n_instances).map(instance_n_layers).max().unwrap();
|
||||
|
||||
if n_layers != sumcheck_proofs.len() {
|
||||
return Err(GkrError::MalformedProof);
|
||||
}
|
||||
|
||||
if gate_by_instance.len() != n_instances {
|
||||
return Err(GkrError::NumInstancesMismatch {
|
||||
given: gate_by_instance.len(),
|
||||
proof: n_instances,
|
||||
});
|
||||
}
|
||||
|
||||
let mut ood_point = vec![];
|
||||
let mut claims_to_verify_by_instance = vec![None; n_instances];
|
||||
|
||||
for (layer, sumcheck_proof) in sumcheck_proofs.iter().enumerate() {
|
||||
let n_remaining_layers = n_layers - layer;
|
||||
|
||||
// Check for output layers.
|
||||
for instance in 0..n_instances {
|
||||
if instance_n_layers(instance) == n_remaining_layers {
|
||||
let output_claims = output_claims_by_instance[instance].clone();
|
||||
claims_to_verify_by_instance[instance] = Some(output_claims);
|
||||
}
|
||||
}
|
||||
|
||||
// Seed the channel with layer claims.
|
||||
for claims_to_verify in claims_to_verify_by_instance.iter().flatten() {
|
||||
channel.mix_felts(claims_to_verify);
|
||||
}
|
||||
|
||||
let sumcheck_alpha = channel.draw_felt();
|
||||
let instance_lambda = channel.draw_felt();
|
||||
|
||||
let mut sumcheck_claims = Vec::new();
|
||||
let mut sumcheck_instances = Vec::new();
|
||||
|
||||
// Prepare the sumcheck claim.
|
||||
for (instance, claims_to_verify) in claims_to_verify_by_instance.iter().enumerate() {
|
||||
if let Some(claims_to_verify) = claims_to_verify {
|
||||
let n_unused_variables = n_layers - instance_n_layers(instance);
|
||||
let doubling_factor = BaseField::from(1 << n_unused_variables);
|
||||
let claim =
|
||||
random_linear_combination(claims_to_verify, instance_lambda) * doubling_factor;
|
||||
sumcheck_claims.push(claim);
|
||||
sumcheck_instances.push(instance);
|
||||
}
|
||||
}
|
||||
|
||||
let sumcheck_claim = random_linear_combination(&sumcheck_claims, sumcheck_alpha);
|
||||
let (sumcheck_ood_point, sumcheck_eval) =
|
||||
sumcheck::partially_verify(sumcheck_claim, sumcheck_proof, channel)
|
||||
.map_err(|source| GkrError::InvalidSumcheck { layer, source })?;
|
||||
|
||||
let mut layer_evals = Vec::new();
|
||||
|
||||
// Evaluate the circuit locally at sumcheck OOD point.
|
||||
for &instance in &sumcheck_instances {
|
||||
let n_unused = n_layers - instance_n_layers(instance);
|
||||
let mask = &layer_masks_by_instance[instance][layer - n_unused];
|
||||
let gate = &gate_by_instance[instance];
|
||||
let gate_output = gate.eval(mask).map_err(|InvalidNumMaskColumnsError| {
|
||||
let instance_layer = instance_n_layers(layer) - n_remaining_layers;
|
||||
GkrError::InvalidMask {
|
||||
instance,
|
||||
instance_layer,
|
||||
}
|
||||
})?;
|
||||
// TODO: Consider simplifying the code by just using the same eq eval for all instances
|
||||
// regardless of size.
|
||||
let eq_eval = eq(&ood_point[n_unused..], &sumcheck_ood_point[n_unused..]);
|
||||
layer_evals.push(eq_eval * random_linear_combination(&gate_output, instance_lambda));
|
||||
}
|
||||
|
||||
let layer_eval = random_linear_combination(&layer_evals, sumcheck_alpha);
|
||||
|
||||
if sumcheck_eval != layer_eval {
|
||||
return Err(GkrError::CircuitCheckFailure {
|
||||
claim: sumcheck_eval,
|
||||
output: layer_eval,
|
||||
layer,
|
||||
});
|
||||
}
|
||||
|
||||
// Seed the channel with the layer masks.
|
||||
for &instance in &sumcheck_instances {
|
||||
let n_unused = n_layers - instance_n_layers(instance);
|
||||
let mask = &layer_masks_by_instance[instance][layer - n_unused];
|
||||
channel.mix_felts(mask.columns().flatten());
|
||||
}
|
||||
|
||||
// Set the OOD evaluation point for layer above.
|
||||
let challenge = channel.draw_felt();
|
||||
ood_point = sumcheck_ood_point;
|
||||
ood_point.push(challenge);
|
||||
|
||||
// Set the claims to verify in the layer above.
|
||||
for instance in sumcheck_instances {
|
||||
let n_unused = n_layers - instance_n_layers(instance);
|
||||
let mask = &layer_masks_by_instance[instance][layer - n_unused];
|
||||
claims_to_verify_by_instance[instance] = Some(mask.reduce_at_point(challenge));
|
||||
}
|
||||
}
|
||||
|
||||
let claims_to_verify_by_instance = claims_to_verify_by_instance
|
||||
.into_iter()
|
||||
.map(Option::unwrap)
|
||||
.collect();
|
||||
|
||||
Ok(GkrArtifact {
|
||||
ood_point,
|
||||
claims_to_verify_by_instance,
|
||||
n_variables_by_instance: (0..n_instances).map(instance_n_layers).collect(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Batch GKR proof.
|
||||
pub struct GkrBatchProof {
|
||||
/// Sum-check proof for each layer.
|
||||
pub sumcheck_proofs: Vec<SumcheckProof>,
|
||||
/// Mask for each layer for each instance.
|
||||
pub layer_masks_by_instance: Vec<Vec<GkrMask>>,
|
||||
/// Column circuit outputs for each instance.
|
||||
pub output_claims_by_instance: Vec<Vec<SecureField>>,
|
||||
}
|
||||
|
||||
/// Values of interest obtained from the execution of the GKR protocol.
|
||||
pub struct GkrArtifact {
|
||||
/// Out-of-domain (OOD) point for evaluating columns in the input layer.
|
||||
pub ood_point: Vec<SecureField>,
|
||||
/// The claimed evaluation at `ood_point` for each column in the input layer of each instance.
|
||||
pub claims_to_verify_by_instance: Vec<Vec<SecureField>>,
|
||||
/// The number of variables that interpolate the input layer of each instance.
|
||||
pub n_variables_by_instance: Vec<usize>,
|
||||
}
|
||||
|
||||
/// Defines how a circuit operates locally on two input rows to produce a single output row.
|
||||
/// This local 2-to-1 constraint is what gives the whole circuit its "binary tree" structure.
|
||||
///
|
||||
/// Binary tree structured circuits have a highly regular wiring pattern that fits the structure of
/// the circuits defined in [Thaler13], which allows for efficient linear-time (linear in the size
/// of the circuit) GKR prover implementations.
|
||||
///
|
||||
/// [Thaler13]: https://eprint.iacr.org/2013/351.pdf
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Gate {
|
||||
LogUp,
|
||||
GrandProduct,
|
||||
}
|
||||
|
||||
impl Gate {
|
||||
/// Returns the output after applying the gate to the mask.
|
||||
fn eval(&self, mask: &GkrMask) -> Result<Vec<SecureField>, InvalidNumMaskColumnsError> {
|
||||
Ok(match self {
|
||||
Self::LogUp => {
|
||||
if mask.columns().len() != 2 {
|
||||
return Err(InvalidNumMaskColumnsError);
|
||||
}
|
||||
|
||||
let [numerator_a, numerator_b] = mask.columns()[0];
|
||||
let [denominator_a, denominator_b] = mask.columns()[1];
|
||||
|
||||
let a = Fraction::new(numerator_a, denominator_a);
|
||||
let b = Fraction::new(numerator_b, denominator_b);
|
||||
let res = a + b;
|
||||
|
||||
vec![res.numerator, res.denominator]
|
||||
}
|
||||
Self::GrandProduct => {
|
||||
if mask.columns().len() != 1 {
|
||||
return Err(InvalidNumMaskColumnsError);
|
||||
}
|
||||
|
||||
let [a, b] = mask.columns()[0];
|
||||
vec![a * b]
|
||||
}
|
||||
})
|
||||
}
|
||||
}
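// Editorial example (not part of the upstream file): a minimal sketch of how a `GrandProduct`
// gate folds the two rows of a single-column mask into their product. It relies on the
// module's existing imports, and the field constructors (`BaseField::from`,
// `SecureField::from`) are assumed to behave as they do elsewhere in this crate; treat this
// as an illustration of `Gate::eval`, not a definitive API test.
#[cfg(test)]
#[test]
fn grand_product_gate_eval_example() {
    let a = SecureField::from(BaseField::from(3));
    let b = SecureField::from(BaseField::from(5));
    let mask = GkrMask::new(vec![[a, b]]);

    // A GrandProduct gate reduces the mask column `[a, b]` to the single value `a * b`.
    assert_eq!(Gate::GrandProduct.eval(&mask).unwrap(), vec![a * b]);
}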
|
||||
|
||||
/// Mask has an invalid number of columns
|
||||
#[derive(Debug)]
|
||||
struct InvalidNumMaskColumnsError;
|
||||
|
||||
/// Stores two evaluations of each column in a GKR layer.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GkrMask {
|
||||
columns: Vec<[SecureField; 2]>,
|
||||
}
|
||||
|
||||
impl GkrMask {
|
||||
pub fn new(columns: Vec<[SecureField; 2]>) -> Self {
|
||||
Self { columns }
|
||||
}
|
||||
|
||||
pub fn to_rows(&self) -> [Vec<SecureField>; 2] {
|
||||
self.columns.iter().map(|&[a, b]| (a, b)).unzip().into()
|
||||
}
|
||||
|
||||
pub fn columns(&self) -> &[[SecureField; 2]] {
|
||||
&self.columns
|
||||
}
|
||||
|
||||
/// Returns all `p_i(x)` where `p_i` interpolates column `i` of the mask on `{0, 1}`.
|
||||
pub fn reduce_at_point(&self, x: SecureField) -> Vec<SecureField> {
|
||||
self.columns
|
||||
.iter()
|
||||
.map(|&[v0, v1]| fold_mle_evals(x, v0, v1))
|
||||
.collect()
|
||||
}
|
||||
}
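// Editorial example (not part of the upstream file): `reduce_at_point` interpolates each mask
// column `[v0, v1]` as the degree-1 polynomial through `(0, v0)` and `(1, v1)` and evaluates
// it at `x`, i.e. it returns `v0 + x * (v1 - v0)` per column. Field constructors are assumed
// to behave as elsewhere in this crate.
#[cfg(test)]
#[test]
fn gkr_mask_reduce_at_point_example() {
    let v0 = SecureField::from(BaseField::from(2));
    let v1 = SecureField::from(BaseField::from(9));
    let x = SecureField::from(BaseField::from(4));

    let mask = GkrMask::new(vec![[v0, v1]]);

    assert_eq!(mask.reduce_at_point(x), vec![v0 + x * (v1 - v0)]);
}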
|
||||
|
||||
/// Error encountered during GKR protocol verification.
|
||||
#[derive(Error, Debug)]
|
||||
pub enum GkrError {
|
||||
/// The proof is malformed.
|
||||
#[error("proof data is invalid")]
|
||||
MalformedProof,
|
||||
/// Mask has an invalid number of columns.
|
||||
#[error("mask in layer {instance_layer} of instance {instance} is invalid")]
|
||||
InvalidMask {
|
||||
instance: usize,
|
||||
/// Layer of the instance (but not necessarily the batch).
|
||||
instance_layer: LayerIndex,
|
||||
},
|
||||
/// There is a mismatch between the number of instances in the proof and the number of
|
||||
/// instances passed for verification.
|
||||
#[error("provided an invalid number of instances (given {given}, proof expects {proof})")]
|
||||
NumInstancesMismatch { given: usize, proof: usize },
|
||||
/// There was an error with one of the sumcheck proofs.
|
||||
#[error("sum-check invalid in layer {layer}: {source}")]
|
||||
InvalidSumcheck {
|
||||
layer: LayerIndex,
|
||||
source: SumcheckError,
|
||||
},
|
||||
/// The circuit polynomial the verifier evaluated doesn't match the claim from the sum-check.
|
||||
#[error("circuit check failed in layer {layer} (calculated {output}, claim {claim})")]
|
||||
CircuitCheckFailure {
|
||||
claim: SecureField,
|
||||
output: SecureField,
|
||||
layer: LayerIndex,
|
||||
},
|
||||
}
|
||||
|
||||
/// GKR layer index where 0 corresponds to the output layer.
|
||||
pub type LayerIndex = usize;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{partially_verify_batch, Gate, GkrArtifact, GkrError};
|
||||
use crate::core::backend::CpuBackend;
|
||||
use crate::core::channel::Channel;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::lookups::gkr_prover::{prove_batch, Layer};
|
||||
use crate::core::lookups::mle::Mle;
|
||||
use crate::core::test_utils::test_channel;
|
||||
|
||||
#[test]
|
||||
fn prove_batch_works() -> Result<(), GkrError> {
|
||||
const LOG_N: usize = 5;
|
||||
let mut channel = test_channel();
|
||||
let col0 = Mle::<CpuBackend, SecureField>::new(channel.draw_felts(1 << LOG_N));
|
||||
let col1 = Mle::<CpuBackend, SecureField>::new(channel.draw_felts(1 << LOG_N));
|
||||
let product0 = col0.iter().product::<SecureField>();
|
||||
let product1 = col1.iter().product::<SecureField>();
|
||||
let input_layers = vec![
|
||||
Layer::GrandProduct(col0.clone()),
|
||||
Layer::GrandProduct(col1.clone()),
|
||||
];
|
||||
let (proof, _) = prove_batch(&mut test_channel(), input_layers);
|
||||
|
||||
let GkrArtifact {
|
||||
ood_point,
|
||||
claims_to_verify_by_instance,
|
||||
n_variables_by_instance,
|
||||
} = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?;
|
||||
|
||||
assert_eq!(n_variables_by_instance, [LOG_N, LOG_N]);
|
||||
assert_eq!(proof.output_claims_by_instance.len(), 2);
|
||||
assert_eq!(claims_to_verify_by_instance.len(), 2);
|
||||
assert_eq!(proof.output_claims_by_instance[0], &[product0]);
|
||||
assert_eq!(proof.output_claims_by_instance[1], &[product1]);
|
||||
let claim0 = &claims_to_verify_by_instance[0];
|
||||
let claim1 = &claims_to_verify_by_instance[1];
|
||||
assert_eq!(claim0, &[col0.eval_at_point(&ood_point)]);
|
||||
assert_eq!(claim1, &[col1.eval_at_point(&ood_point)]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prove_batch_with_different_sizes_works() -> Result<(), GkrError> {
|
||||
const LOG_N0: usize = 5;
|
||||
const LOG_N1: usize = 7;
|
||||
let mut channel = test_channel();
|
||||
let col0 = Mle::<CpuBackend, SecureField>::new(channel.draw_felts(1 << LOG_N0));
|
||||
let col1 = Mle::<CpuBackend, SecureField>::new(channel.draw_felts(1 << LOG_N1));
|
||||
let product0 = col0.iter().product::<SecureField>();
|
||||
let product1 = col1.iter().product::<SecureField>();
|
||||
let input_layers = vec![
|
||||
Layer::GrandProduct(col0.clone()),
|
||||
Layer::GrandProduct(col1.clone()),
|
||||
];
|
||||
let (proof, _) = prove_batch(&mut test_channel(), input_layers);
|
||||
|
||||
let GkrArtifact {
|
||||
ood_point,
|
||||
claims_to_verify_by_instance,
|
||||
n_variables_by_instance,
|
||||
} = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?;
|
||||
|
||||
assert_eq!(n_variables_by_instance, [LOG_N0, LOG_N1]);
|
||||
assert_eq!(proof.output_claims_by_instance.len(), 2);
|
||||
assert_eq!(claims_to_verify_by_instance.len(), 2);
|
||||
assert_eq!(proof.output_claims_by_instance[0], &[product0]);
|
||||
assert_eq!(proof.output_claims_by_instance[1], &[product1]);
|
||||
let claim0 = &claims_to_verify_by_instance[0];
|
||||
let claim1 = &claims_to_verify_by_instance[1];
|
||||
let n_vars = ood_point.len();
|
||||
assert_eq!(claim0, &[col0.eval_at_point(&ood_point[n_vars - LOG_N0..])]);
|
||||
assert_eq!(claim1, &[col1.eval_at_point(&ood_point[n_vars - LOG_N1..])]);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
106
Stwo_wrapper/crates/prover/src/core/lookups/mle.rs
Normal file
@ -0,0 +1,106 @@
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use educe::Educe;
|
||||
|
||||
use crate::core::backend::{Col, Column, ColumnOps};
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::Field;
|
||||
|
||||
pub trait MleOps<F: Field>: ColumnOps<F> + Sized {
|
||||
/// Returns a transformed [`Mle`] where the first variable is fixed to `assignment`.
|
||||
fn fix_first_variable(mle: Mle<Self, F>, assignment: SecureField) -> Mle<Self, SecureField>
|
||||
where
|
||||
Self: MleOps<SecureField>;
|
||||
}
|
||||
|
||||
/// Multilinear Extension stored as evaluations of a multilinear polynomial over the boolean
|
||||
/// hypercube in bit-reversed order.
|
||||
#[derive(Educe)]
|
||||
#[educe(Debug, Clone)]
|
||||
pub struct Mle<B: ColumnOps<F>, F: Field> {
|
||||
evals: Col<B, F>,
|
||||
}
|
||||
|
||||
impl<B: MleOps<F>, F: Field> Mle<B, F> {
|
||||
/// Creates a [`Mle`] from evaluations of a multilinear polynomial on the boolean hypercube.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the number of evaluations is not a power of two.
|
||||
pub fn new(evals: Col<B, F>) -> Self {
|
||||
assert!(evals.len().is_power_of_two());
|
||||
Self { evals }
|
||||
}
|
||||
|
||||
pub fn into_evals(self) -> Col<B, F> {
|
||||
self.evals
|
||||
}
|
||||
|
||||
/// Returns a transformed polynomial where the first variable is fixed to `assignment`.
|
||||
pub fn fix_first_variable(self, assignment: SecureField) -> Mle<B, SecureField>
|
||||
where
|
||||
B: MleOps<SecureField>,
|
||||
{
|
||||
B::fix_first_variable(self, assignment)
|
||||
}
|
||||
|
||||
/// Returns the number of variables in the polynomial.
|
||||
pub fn n_variables(&self) -> usize {
|
||||
self.evals.len().ilog2() as usize
|
||||
}
|
||||
}
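// Editorial example (not part of the upstream file): an MLE over `n` variables stores `2^n`
// evaluations, so `n_variables` is simply the base-2 logarithm of the number of evaluations
// passed to `new`. `CpuBackend` is assumed to implement `MleOps<SecureField>`, as it does in
// the tests elsewhere in this crate.
#[cfg(test)]
#[test]
fn mle_n_variables_example() {
    use num_traits::Zero;

    use crate::core::backend::CpuBackend;

    // 8 = 2^3 evaluations correspond to a 3-variable multilinear polynomial.
    let mle = Mle::<CpuBackend, SecureField>::new(vec![SecureField::zero(); 8]);
    assert_eq!(mle.n_variables(), 3);
}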
|
||||
|
||||
impl<B: ColumnOps<F>, F: Field> Deref for Mle<B, F> {
|
||||
type Target = Col<B, F>;
|
||||
|
||||
fn deref(&self) -> &Col<B, F> {
|
||||
&self.evals
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: ColumnOps<F>, F: Field> DerefMut for Mle<B, F> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.evals
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::{Mle, MleOps};
|
||||
use crate::core::backend::Column;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::{ExtensionOf, Field};
|
||||
|
||||
impl<B, F> Mle<B, F>
|
||||
where
|
||||
F: Field,
|
||||
SecureField: ExtensionOf<F>,
|
||||
B: MleOps<F>,
|
||||
{
|
||||
/// Evaluates the multilinear polynomial at `point`.
|
||||
pub(crate) fn eval_at_point(&self, point: &[SecureField]) -> SecureField {
|
||||
pub fn eval(mle_evals: &[SecureField], p: &[SecureField]) -> SecureField {
|
||||
match p {
|
||||
[] => mle_evals[0],
|
||||
&[p_i, ref p @ ..] => {
|
||||
let (lhs, rhs) = mle_evals.split_at(mle_evals.len() / 2);
|
||||
let lhs_eval = eval(lhs, p);
|
||||
let rhs_eval = eval(rhs, p);
|
||||
// Equivalent to `eq(0, p_i) * lhs_eval + eq(1, p_i) * rhs_eval`.
|
||||
p_i * (rhs_eval - lhs_eval) + lhs_eval
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mle_evals = self
|
||||
.clone()
|
||||
.into_evals()
|
||||
.to_cpu()
|
||||
.into_iter()
|
||||
.map(|v| v.into())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
eval(&mle_evals, point)
|
||||
}
|
||||
}
|
||||
}
|
||||
5
Stwo_wrapper/crates/prover/src/core/lookups/mod.rs
Normal file
@ -0,0 +1,5 @@
|
||||
pub mod gkr_prover;
|
||||
pub mod gkr_verifier;
|
||||
pub mod mle;
|
||||
pub mod sumcheck;
|
||||
pub mod utils;
|
||||
292
Stwo_wrapper/crates/prover/src/core/lookups/sumcheck.rs
Normal file
@ -0,0 +1,292 @@
|
||||
//! Sum-check protocol that proves and verifies claims about `sum_x g(x)` for all x in `{0, 1}^n`.
|
||||
//!
|
||||
//! [`MultivariatePolyOracle`] provides methods for evaluating sums and making transformations on
|
||||
//! `g` in the context of the protocol. It is intended to be used in conjunction with
|
||||
//! [`prove_batch()`] to generate proofs.
|
||||
|
||||
use std::iter::zip;
|
||||
|
||||
use itertools::Itertools;
|
||||
use num_traits::{One, Zero};
|
||||
use thiserror::Error;
|
||||
|
||||
use super::utils::UnivariatePoly;
|
||||
use crate::core::channel::Channel;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
|
||||
/// Something that can be seen as a multivariate polynomial `g(x_0, ..., x_{n-1})`.
|
||||
pub trait MultivariatePolyOracle: Sized {
|
||||
/// Returns the number of variables in `g`.
|
||||
fn n_variables(&self) -> usize;
|
||||
|
||||
/// Computes the sum of `g(x_0, x_1, ..., x_{n-1})` over all `(x_1, ..., x_{n-1})` in
|
||||
/// `{0, 1}^(n-1)`, effectively reducing the sum over `g` to a univariate polynomial in `x_0`.
|
||||
///
|
||||
/// `claim` equals the claimed sum of `g(x_0, x_1, ..., x_{n-1})` over all `(x_0, ..., x_{n-1})`
|
||||
/// in `{0, 1}^n`. Knowing the claim can help optimize the implementation: Let `f` denote the
|
||||
/// univariate polynomial we want to return. Note that `claim = f(0) + f(1)` so knowing `claim`
|
||||
/// and either `f(0)` or `f(1)` allows determining the other.
|
||||
fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly<SecureField>;
|
||||
|
||||
/// Returns a transformed oracle where the first variable of `g` is fixed to `challenge`.
|
||||
///
|
||||
/// The returned oracle represents the multivariate polynomial `g'`, defined as
|
||||
/// `g'(x_1, ..., x_{n-1}) = g(challenge, x_1, ..., x_{n-1})`.
|
||||
fn fix_first_variable(self, challenge: SecureField) -> Self;
|
||||
}
|
||||
|
||||
/// Performs sum-check on a random linear combination of multiple multivariate polynomials.
|
||||
///
|
||||
/// Let the multivariate polynomials be `g_0, ..., g_{n-1}`. A single sum-check is performed on
|
||||
/// multivariate polynomial `h = g_0 + lambda * g_1 + ... + lambda^(n-1) * g_{n-1}`. The `g_i`s do
|
||||
/// not need to have the same number of variables. `g_i`s with less variables are folded in the
|
||||
/// latest possible round of the protocol. For instance with `g_0(x, y, z)` and `g_1(x, y)`
|
||||
/// sum-check is performed on `h(x, y, z) = g_0(x, y, z) + lambda * g_1(y, z)`. Claim `c_i` should
|
||||
/// equal the claimed sum of `g_i(x_0, ..., x_{j-1})` over all `(x_0, ..., x_{j-1})` in `{0, 1}^j`.
|
||||
///
|
||||
/// The degree of each `g_i` should not exceed [`MAX_DEGREE`] in any variable. The sum-check proof
|
||||
/// of `h`, list of challenges (variable assignment) and the constant oracles (i.e. the `g_i` with
|
||||
/// all variables fixed to their corresponding challenges) are returned.
|
||||
///
|
||||
/// Output is of the form: `(proof, variable_assignment, constant_poly_oracles, claimed_evals)`
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if:
|
||||
/// - No multivariate polynomials are provided.
|
||||
/// - There aren't the same number of multivariate polynomials and claims.
|
||||
/// - The degree of any multivariate polynomial exceeds [`MAX_DEGREE`] in any variable.
|
||||
/// - The round polynomials are inconsistent with their corresponding claimed sum on `0` and `1`.
|
||||
// TODO: Consider returning constant oracles as separate type.
|
||||
pub fn prove_batch<O: MultivariatePolyOracle>(
|
||||
mut claims: Vec<SecureField>,
|
||||
mut multivariate_polys: Vec<O>,
|
||||
lambda: SecureField,
|
||||
channel: &mut impl Channel,
|
||||
) -> (SumcheckProof, Vec<SecureField>, Vec<O>, Vec<SecureField>) {
|
||||
let n_variables = multivariate_polys.iter().map(O::n_variables).max().unwrap();
|
||||
assert_eq!(claims.len(), multivariate_polys.len());
|
||||
|
||||
let mut round_polys = Vec::new();
|
||||
let mut assignment = Vec::new();
|
||||
|
||||
// Update the claims for the sum over `h`'s hypercube.
|
||||
for (claim, multivariate_poly) in zip(&mut claims, &multivariate_polys) {
|
||||
let n_unused_variables = n_variables - multivariate_poly.n_variables();
|
||||
*claim *= BaseField::from(1 << n_unused_variables);
|
||||
}
|
||||
|
||||
// Prove sum-check rounds
|
||||
for round in 0..n_variables {
|
||||
let n_remaining_rounds = n_variables - round;
|
||||
|
||||
let this_round_polys = zip(&multivariate_polys, &claims)
|
||||
.enumerate()
|
||||
.map(|(i, (multivariate_poly, &claim))| {
|
||||
let round_poly = if n_remaining_rounds == multivariate_poly.n_variables() {
|
||||
multivariate_poly.sum_as_poly_in_first_variable(claim)
|
||||
} else {
|
||||
(claim / BaseField::from(2)).into()
|
||||
};
|
||||
|
||||
let eval_at_0 = round_poly.eval_at_point(SecureField::zero());
|
||||
let eval_at_1 = round_poly.eval_at_point(SecureField::one());
|
||||
assert_eq!(eval_at_0 + eval_at_1, claim, "i={i}, round={round}");
|
||||
assert!(round_poly.degree() <= MAX_DEGREE, "i={i}, round={round}");
|
||||
|
||||
round_poly
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let round_poly = random_linear_combination(&this_round_polys, lambda);
|
||||
|
||||
channel.mix_felts(&round_poly);
|
||||
|
||||
let challenge = channel.draw_felt();
|
||||
|
||||
claims = this_round_polys
|
||||
.iter()
|
||||
.map(|round_poly| round_poly.eval_at_point(challenge))
|
||||
.collect();
|
||||
|
||||
multivariate_polys = multivariate_polys
|
||||
.into_iter()
|
||||
.map(|multivariate_poly| {
|
||||
if n_remaining_rounds != multivariate_poly.n_variables() {
|
||||
return multivariate_poly;
|
||||
}
|
||||
|
||||
multivariate_poly.fix_first_variable(challenge)
|
||||
})
|
||||
.collect();
|
||||
|
||||
round_polys.push(round_poly);
|
||||
assignment.push(challenge);
|
||||
}
|
||||
|
||||
let proof = SumcheckProof { round_polys };
|
||||
|
||||
(proof, assignment, multivariate_polys, claims)
|
||||
}
|
||||
|
||||
/// Returns `p_0 + alpha * p_1 + ... + alpha^(n-1) * p_{n-1}`.
|
||||
fn random_linear_combination(
|
||||
polys: &[UnivariatePoly<SecureField>],
|
||||
alpha: SecureField,
|
||||
) -> UnivariatePoly<SecureField> {
|
||||
polys
|
||||
.iter()
|
||||
.rfold(Zero::zero(), |acc, poly| acc * alpha + poly.clone())
|
||||
}
|
||||
|
||||
/// Partially verifies a sum-check proof.
|
||||
///
|
||||
/// Only "partial" since it does not fully verify the prover's claimed evaluation on the variable
|
||||
/// assignment but checks if the sum of the round polynomials evaluated on `0` and `1` matches the
|
||||
/// claim for each round. If the proof passes these checks, the variable assignment and the prover's
|
||||
/// claimed evaluation are returned for the caller to validate; otherwise an [`Err`] is returned.
|
||||
///
|
||||
/// Output is of the form `(variable_assignment, claimed_eval)`.
|
||||
pub fn partially_verify(
|
||||
mut claim: SecureField,
|
||||
proof: &SumcheckProof,
|
||||
channel: &mut impl Channel,
|
||||
) -> Result<(Vec<SecureField>, SecureField), SumcheckError> {
|
||||
let mut assignment = Vec::new();
|
||||
|
||||
for (round, round_poly) in proof.round_polys.iter().enumerate() {
|
||||
if round_poly.degree() > MAX_DEGREE {
|
||||
return Err(SumcheckError::DegreeInvalid { round });
|
||||
}
|
||||
|
||||
// TODO: optimize this by sending one less coefficient, and computing it from the
|
||||
// claim, instead of checking the claim. (Can also be done by quotienting).
|
||||
let sum = round_poly.eval_at_point(Zero::zero()) + round_poly.eval_at_point(One::one());
|
||||
|
||||
if claim != sum {
|
||||
return Err(SumcheckError::SumInvalid { claim, sum, round });
|
||||
}
|
||||
|
||||
channel.mix_felts(round_poly);
|
||||
let challenge = channel.draw_felt();
|
||||
claim = round_poly.eval_at_point(challenge);
|
||||
assignment.push(challenge);
|
||||
}
|
||||
|
||||
Ok((assignment, claim))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SumcheckProof {
|
||||
pub round_polys: Vec<UnivariatePoly<SecureField>>,
|
||||
}
|
||||
|
||||
/// Max degree of polynomials the verifier accepts in each round of the protocol.
|
||||
pub const MAX_DEGREE: usize = 3;
|
||||
|
||||
/// Sum-check protocol verification error.
|
||||
#[derive(Error, Debug)]
|
||||
pub enum SumcheckError {
|
||||
#[error("degree of the polynomial in round {round} is too high")]
|
||||
DegreeInvalid { round: RoundIndex },
|
||||
#[error("sum does not match the claim in round {round} (sum {sum}, claim {claim})")]
|
||||
SumInvalid {
|
||||
claim: SecureField,
|
||||
sum: SecureField,
|
||||
round: RoundIndex,
|
||||
},
|
||||
}
|
||||
|
||||
/// Sum-check round index where 0 corresponds to the first round.
|
||||
pub type RoundIndex = usize;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use num_traits::One;
|
||||
|
||||
use crate::core::backend::CpuBackend;
|
||||
use crate::core::channel::{Blake2sChannel, Channel};
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::Field;
|
||||
use crate::core::lookups::mle::Mle;
|
||||
use crate::core::lookups::sumcheck::{partially_verify, prove_batch};
|
||||
|
||||
#[test]
|
||||
fn sumcheck_works() {
|
||||
let values = test_channel().draw_felts(32);
|
||||
let claim = values.iter().sum();
|
||||
let mle = Mle::<CpuBackend, SecureField>::new(values);
|
||||
let lambda = SecureField::one();
|
||||
let (proof, ..) = prove_batch(vec![claim], vec![mle.clone()], lambda, &mut test_channel());
|
||||
|
||||
let (assignment, eval) = partially_verify(claim, &proof, &mut test_channel()).unwrap();
|
||||
|
||||
assert_eq!(eval, mle.eval_at_point(&assignment));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn batch_sumcheck_works() {
|
||||
let mut channel = test_channel();
|
||||
let values0 = channel.draw_felts(32);
|
||||
let values1 = channel.draw_felts(32);
|
||||
let claim0 = values0.iter().sum();
|
||||
let claim1 = values1.iter().sum();
|
||||
let mle0 = Mle::<CpuBackend, SecureField>::new(values0.clone());
|
||||
let mle1 = Mle::<CpuBackend, SecureField>::new(values1.clone());
|
||||
let lambda = channel.draw_felt();
|
||||
let claims = vec![claim0, claim1];
|
||||
let mles = vec![mle0.clone(), mle1.clone()];
|
||||
let (proof, ..) = prove_batch(claims, mles, lambda, &mut test_channel());
|
||||
|
||||
let claim = claim0 + lambda * claim1;
|
||||
let (assignment, eval) = partially_verify(claim, &proof, &mut test_channel()).unwrap();
|
||||
|
||||
let eval0 = mle0.eval_at_point(&assignment);
|
||||
let eval1 = mle1.eval_at_point(&assignment);
|
||||
assert_eq!(eval, eval0 + lambda * eval1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn batch_sumcheck_with_different_n_variables() {
|
||||
let mut channel = test_channel();
|
||||
let values0 = channel.draw_felts(64);
|
||||
let values1 = channel.draw_felts(32);
|
||||
let claim0 = values0.iter().sum();
|
||||
let claim1 = values1.iter().sum();
|
||||
let mle0 = Mle::<CpuBackend, SecureField>::new(values0.clone());
|
||||
let mle1 = Mle::<CpuBackend, SecureField>::new(values1.clone());
|
||||
let lambda = channel.draw_felt();
|
||||
let claims = vec![claim0, claim1];
|
||||
let mles = vec![mle0.clone(), mle1.clone()];
|
||||
let (proof, ..) = prove_batch(claims, mles, lambda, &mut test_channel());
|
||||
|
||||
let claim = claim0 + lambda * claim1.double();
|
||||
let (assignment, eval) = partially_verify(claim, &proof, &mut test_channel()).unwrap();
|
||||
|
||||
let eval0 = mle0.eval_at_point(&assignment);
|
||||
let eval1 = mle1.eval_at_point(&assignment[1..]);
|
||||
assert_eq!(eval, eval0 + lambda * eval1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_sumcheck_proof_fails() {
|
||||
let values = test_channel().draw_felts(8);
|
||||
let claim = values.iter().sum::<SecureField>();
|
||||
let lambda = SecureField::one();
|
||||
// Compromise the first value.
|
||||
let mut invalid_values = values;
|
||||
invalid_values[0] += SecureField::one();
|
||||
let invalid_claim = vec![invalid_values.iter().sum::<SecureField>()];
|
||||
let invalid_mle = vec![Mle::<CpuBackend, SecureField>::new(invalid_values.clone())];
|
||||
let (invalid_proof, ..) =
|
||||
prove_batch(invalid_claim, invalid_mle, lambda, &mut test_channel());
|
||||
|
||||
assert!(partially_verify(claim, &invalid_proof, &mut test_channel()).is_err());
|
||||
}
|
||||
|
||||
fn test_channel() -> Blake2sChannel {
|
||||
Blake2sChannel::default()
|
||||
}
|
||||
}
|
||||
356
Stwo_wrapper/crates/prover/src/core/lookups/utils.rs
Normal file
@ -0,0 +1,356 @@
|
||||
use std::iter::{zip, Sum};
|
||||
use std::ops::{Add, Deref, Mul, Neg, Sub};
|
||||
|
||||
use num_traits::{One, Zero};
|
||||
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::{ExtensionOf, Field};
|
||||
|
||||
/// Univariate polynomial stored as coefficients in the monomial basis.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UnivariatePoly<F: Field>(Vec<F>);
|
||||
|
||||
impl<F: Field> UnivariatePoly<F> {
|
||||
pub fn new(coeffs: Vec<F>) -> Self {
|
||||
let mut polynomial = Self(coeffs);
|
||||
polynomial.truncate_leading_zeros();
|
||||
polynomial
|
||||
}
|
||||
|
||||
pub fn eval_at_point(&self, x: F) -> F {
|
||||
horner_eval(&self.0, x)
|
||||
}
|
||||
|
||||
// <https://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Polynomial_interpolation>
|
||||
pub fn interpolate_lagrange(xs: &[F], ys: &[F]) -> Self {
|
||||
assert_eq!(xs.len(), ys.len());
|
||||
|
||||
let mut coeffs = Self::zero();
|
||||
|
||||
for (i, (&xi, &yi)) in zip(xs, ys).enumerate() {
|
||||
let mut prod = yi;
|
||||
|
||||
for (j, &xj) in xs.iter().enumerate() {
|
||||
if i != j {
|
||||
prod /= xi - xj;
|
||||
}
|
||||
}
|
||||
|
||||
let mut term = Self::new(vec![prod]);
|
||||
|
||||
for (j, &xj) in xs.iter().enumerate() {
|
||||
if i != j {
|
||||
term = term * (Self::x() - Self::new(vec![xj]));
|
||||
}
|
||||
}
|
||||
|
||||
coeffs = coeffs + term;
|
||||
}
|
||||
|
||||
coeffs.truncate_leading_zeros();
|
||||
|
||||
coeffs
|
||||
}
|
||||
|
||||
    pub fn degree(&self) -> usize {
        // Trailing zero coefficients do not contribute to the degree; the index of the
        // highest non-zero coefficient is the degree (0 for the zero polynomial).
        self.0
            .iter()
            .rposition(|coeff| !coeff.is_zero())
            .unwrap_or(0)
    }
|
||||
|
||||
fn x() -> Self {
|
||||
Self(vec![F::zero(), F::one()])
|
||||
}
|
||||
|
||||
fn truncate_leading_zeros(&mut self) {
|
||||
while self.0.last() == Some(&F::zero()) {
|
||||
self.0.pop();
|
||||
}
|
||||
}
|
||||
}
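// Editorial example (not part of the upstream file): `degree` ignores trailing zero
// coefficients, so a polynomial constructed with redundant high-order zeros still reports the
// degree of its highest non-zero term.
#[cfg(test)]
#[test]
fn univariate_poly_degree_example() {
    use crate::core::fields::m31::BaseField;

    // 1 + 2x + 3x^2 (+ 0 * x^3) has degree 2.
    let poly = UnivariatePoly::new(vec![
        BaseField::from(1),
        BaseField::from(2),
        BaseField::from(3),
        BaseField::from(0),
    ]);

    assert_eq!(poly.degree(), 2);
}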
|
||||
|
||||
impl<F: Field> From<F> for UnivariatePoly<F> {
|
||||
fn from(value: F) -> Self {
|
||||
Self::new(vec![value])
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Mul<F> for UnivariatePoly<F> {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(mut self, rhs: F) -> Self {
|
||||
self.0.iter_mut().for_each(|coeff| *coeff *= rhs);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Mul for UnivariatePoly<F> {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(mut self, mut rhs: Self) -> Self {
|
||||
if self.is_zero() || rhs.is_zero() {
|
||||
return Self::zero();
|
||||
}
|
||||
|
||||
self.truncate_leading_zeros();
|
||||
rhs.truncate_leading_zeros();
|
||||
|
||||
let mut res = vec![F::zero(); self.0.len() + rhs.0.len() - 1];
|
||||
|
||||
for (i, coeff_a) in self.0.into_iter().enumerate() {
|
||||
for (j, &coeff_b) in rhs.0.iter().enumerate() {
|
||||
res[i + j] += coeff_a * coeff_b;
|
||||
}
|
||||
}
|
||||
|
||||
Self::new(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Add for UnivariatePoly<F> {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self {
|
||||
let n = self.0.len().max(rhs.0.len());
|
||||
let mut res = Vec::new();
|
||||
|
||||
for i in 0..n {
|
||||
res.push(match (self.0.get(i), rhs.0.get(i)) {
|
||||
(Some(&a), Some(&b)) => a + b,
|
||||
(Some(&a), None) | (None, Some(&a)) => a,
|
||||
_ => unreachable!(),
|
||||
})
|
||||
}
|
||||
|
||||
Self(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Sub for UnivariatePoly<F> {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self {
|
||||
self + (-rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Neg for UnivariatePoly<F> {
|
||||
type Output = Self;
|
||||
|
||||
fn neg(self) -> Self {
|
||||
Self(self.0.into_iter().map(|v| -v).collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Zero for UnivariatePoly<F> {
|
||||
fn zero() -> Self {
|
||||
Self(vec![])
|
||||
}
|
||||
|
||||
fn is_zero(&self) -> bool {
|
||||
self.0.iter().all(F::is_zero)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Deref for UnivariatePoly<F> {
|
||||
type Target = [F];
|
||||
|
||||
fn deref(&self) -> &[F] {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluates univariate polynomial using [Horner's method].
|
||||
///
|
||||
/// [Horner's method]: https://en.wikipedia.org/wiki/Horner%27s_method
|
||||
pub fn horner_eval<F: Field>(coeffs: &[F], x: F) -> F {
|
||||
coeffs
|
||||
.iter()
|
||||
.rfold(F::zero(), |acc, &coeff| acc * x + coeff)
|
||||
}
|
||||
|
||||
/// Returns `v_0 + alpha * v_1 + ... + alpha^(n-1) * v_{n-1}`.
|
||||
pub fn random_linear_combination(v: &[SecureField], alpha: SecureField) -> SecureField {
|
||||
horner_eval(v, alpha)
|
||||
}
|
||||
|
||||
/// Evaluates the Lagrange kernel of the boolean hypercube.
///
/// The Lagrange kernel of the boolean hypercube is a multilinear extension of the function that
/// when given `x, y` in `{0, 1}^n` evaluates to 1 if `x = y`, and evaluates to 0 otherwise.
|
||||
pub fn eq<F: Field>(x: &[F], y: &[F]) -> F {
|
||||
assert_eq!(x.len(), y.len());
|
||||
zip(x, y)
|
||||
.map(|(&xi, &yi)| xi * yi + (F::one() - xi) * (F::one() - yi))
|
||||
.product()
|
||||
}
|
||||
|
||||
/// Computes `eq(0, assignment) * eval0 + eq(1, assignment) * eval1`.
|
||||
pub fn fold_mle_evals<F>(assignment: SecureField, eval0: F, eval1: F) -> SecureField
|
||||
where
|
||||
F: Field,
|
||||
SecureField: ExtensionOf<F>,
|
||||
{
|
||||
assignment * (eval1 - eval0) + eval0
|
||||
}
|
||||
|
||||
/// Projective fraction.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Fraction<N, D> {
|
||||
pub numerator: N,
|
||||
pub denominator: D,
|
||||
}
|
||||
|
||||
impl<N, D> Fraction<N, D> {
|
||||
pub fn new(numerator: N, denominator: D) -> Self {
|
||||
Self {
|
||||
numerator,
|
||||
denominator,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, D: Add<Output = D> + Add<N, Output = D> + Mul<N, Output = D> + Mul<Output = D> + Copy> Add
|
||||
for Fraction<N, D>
|
||||
{
|
||||
type Output = Fraction<D, D>;
|
||||
|
||||
fn add(self, rhs: Self) -> Fraction<D, D> {
|
||||
Fraction {
|
||||
numerator: rhs.denominator * self.numerator + self.denominator * rhs.numerator,
|
||||
denominator: self.denominator * rhs.denominator,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: Zero, D: One + Zero> Zero for Fraction<N, D>
|
||||
where
|
||||
Self: Add<Output = Self>,
|
||||
{
|
||||
fn zero() -> Self {
|
||||
Self {
|
||||
numerator: N::zero(),
|
||||
denominator: D::one(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_zero(&self) -> bool {
|
||||
self.numerator.is_zero() && !self.denominator.is_zero()
|
||||
}
|
||||
}
|
||||
|
||||
impl<N, D> Sum for Fraction<N, D>
|
||||
where
|
||||
Self: Zero,
|
||||
{
|
||||
fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
|
||||
let first = iter.next().unwrap_or_else(Self::zero);
|
||||
iter.fold(first, |a, b| a + b)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the fraction `1 / x`
|
||||
pub struct Reciprocal<T> {
|
||||
x: T,
|
||||
}
|
||||
|
||||
impl<T> Reciprocal<T> {
|
||||
pub fn new(x: T) -> Self {
|
||||
Self { x }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Add<Output = T> + Mul<Output = T> + Copy> Add for Reciprocal<T> {
|
||||
type Output = Fraction<T, T>;
|
||||
|
||||
fn add(self, rhs: Self) -> Fraction<T, T> {
|
||||
// `1/a + 1/b = (a + b)/(a * b)`
|
||||
Fraction {
|
||||
numerator: self.x + rhs.x,
|
||||
denominator: self.x * rhs.x,
|
||||
}
|
||||
}
|
||||
}
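// Editorial example (not part of the upstream file): adding two reciprocals `1/a + 1/b`
// yields the projective fraction `(a + b) / (a * b)` without performing any field inversion.
#[cfg(test)]
#[test]
fn reciprocal_addition_example() {
    use crate::core::fields::m31::BaseField;

    let res = Reciprocal::new(BaseField::from(2)) + Reciprocal::new(BaseField::from(3));

    // 1/2 + 1/3 = 5/6.
    assert_eq!(
        res.numerator / res.denominator,
        BaseField::from(5) / BaseField::from(6)
    );
}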
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::iter::zip;
|
||||
|
||||
use num_traits::{One, Zero};
|
||||
|
||||
use super::{horner_eval, UnivariatePoly};
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fields::FieldExpOps;
|
||||
use crate::core::lookups::utils::{eq, Fraction};
|
||||
|
||||
#[test]
|
||||
fn lagrange_interpolation_works() {
|
||||
let xs = [5, 1, 3, 9].map(BaseField::from);
|
||||
let ys = [1, 2, 3, 4].map(BaseField::from);
|
||||
|
||||
let poly = UnivariatePoly::interpolate_lagrange(&xs, &ys);
|
||||
|
||||
for (x, y) in zip(xs, ys) {
|
||||
assert_eq!(poly.eval_at_point(x), y, "mismatch for x={x}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn horner_eval_works() {
|
||||
let coeffs = [BaseField::from(9), BaseField::from(2), BaseField::from(3)];
|
||||
let x = BaseField::from(7);
|
||||
|
||||
let eval = horner_eval(&coeffs, x);
|
||||
|
||||
assert_eq!(eval, coeffs[0] + coeffs[1] * x + coeffs[2] * x.square());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eq_identical_hypercube_points_returns_one() {
|
||||
let zero = SecureField::zero();
|
||||
let one = SecureField::one();
|
||||
let a = &[one, zero, one];
|
||||
|
||||
let eq_eval = eq(a, a);
|
||||
|
||||
assert_eq!(eq_eval, one);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eq_different_hypercube_points_returns_zero() {
|
||||
let zero = SecureField::zero();
|
||||
let one = SecureField::one();
|
||||
let a = &[one, zero, one];
|
||||
let b = &[one, zero, zero];
|
||||
|
||||
let eq_eval = eq(a, b);
|
||||
|
||||
assert_eq!(eq_eval, zero);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn eq_different_size_points() {
|
||||
let zero = SecureField::zero();
|
||||
let one = SecureField::one();
|
||||
|
||||
eq(&[zero, one], &[zero]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fraction_addition_works() {
|
||||
let a = Fraction::new(BaseField::from(1), BaseField::from(3));
|
||||
let b = Fraction::new(BaseField::from(2), BaseField::from(6));
|
||||
|
||||
let Fraction {
|
||||
numerator,
|
||||
denominator,
|
||||
} = a + b;
|
||||
|
||||
assert_eq!(
|
||||
numerator / denominator,
|
||||
BaseField::from(2) / BaseField::from(3)
|
||||
);
|
||||
}
|
||||
}
|
||||
59
Stwo_wrapper/crates/prover/src/core/mod.rs
Normal file
@ -0,0 +1,59 @@
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
pub mod air;
|
||||
pub mod backend;
|
||||
pub mod channel;
|
||||
pub mod circle;
|
||||
pub mod constraints;
|
||||
pub mod fft;
|
||||
pub mod fields;
|
||||
pub mod fri;
|
||||
pub mod lookups;
|
||||
pub mod pcs;
|
||||
pub mod poly;
|
||||
pub mod proof_of_work;
|
||||
pub mod prover;
|
||||
pub mod queries;
|
||||
#[cfg(test)]
|
||||
pub mod test_utils;
|
||||
pub mod utils;
|
||||
pub mod vcs;
|
||||
|
||||
/// A vector in which each element relates (by index) to a column in the trace.
|
||||
pub type ColumnVec<T> = Vec<T>;
|
||||
|
||||
/// A vector of [ColumnVec]s. Each [ColumnVec] relates (by index) to a component in the air.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ComponentVec<T>(pub Vec<ColumnVec<T>>);
|
||||
|
||||
impl<T> ComponentVec<T> {
|
||||
pub fn flatten(self) -> ColumnVec<T> {
|
||||
self.0.into_iter().flatten().collect()
|
||||
}
|
||||
}
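// Editorial example (not part of the upstream file): `flatten` concatenates the per-component
// column vectors into one flat `ColumnVec`, preserving component order.
#[cfg(test)]
#[test]
fn component_vec_flatten_example() {
    let components = ComponentVec(vec![vec![1u32, 2], vec![3]]);

    assert_eq!(components.flatten(), vec![1, 2, 3]);
}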
|
||||
|
||||
impl<T> ComponentVec<ColumnVec<T>> {
|
||||
pub fn flatten_cols(self) -> Vec<T> {
|
||||
self.0.into_iter().flatten().flatten().collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Default for ComponentVec<T> {
|
||||
fn default() -> Self {
|
||||
Self(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Deref for ComponentVec<T> {
|
||||
type Target = Vec<ColumnVec<T>>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> DerefMut for ComponentVec<T> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
40
Stwo_wrapper/crates/prover/src/core/pcs/mod.rs
Normal file
@ -0,0 +1,40 @@
|
||||
//! Implements a FRI polynomial commitment scheme.
|
||||
//! This is a protocol where the prover can commit to a set of polynomials and then prove their
//! openings at a set of points.
|
||||
//! Note: This implementation is not really a polynomial commitment scheme, because we are not in
|
||||
//! the unique decoding regime. This is enough for a STARK proof though, where we only want to imply
|
||||
//! the existence of such polynomials, and are ok with having a small decoding list.
|
||||
//! Note: Opened points cannot come from the commitment domain.
|
||||
|
||||
mod prover;
|
||||
pub mod quotients;
|
||||
mod utils;
|
||||
mod verifier;
|
||||
|
||||
pub use self::prover::{
|
||||
CommitmentSchemeProof, CommitmentSchemeProver, CommitmentTreeProver, TreeBuilder,
|
||||
};
|
||||
pub use self::utils::TreeVec;
|
||||
pub use self::verifier::CommitmentSchemeVerifier;
|
||||
use super::fri::FriConfig;
|
||||
|
||||
#[derive(Copy, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct TreeSubspan {
|
||||
pub tree_index: usize,
|
||||
pub col_start: usize,
|
||||
pub col_end: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct PcsConfig {
|
||||
pub pow_bits: u32,
|
||||
pub fri_config: FriConfig,
|
||||
}
|
||||
impl Default for PcsConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
pow_bits: 5,
|
||||
fri_config: FriConfig::new(0, 1, 3),
|
||||
}
|
||||
}
|
||||
}
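// Editorial example (not part of the upstream file): constructing a non-default config. The
// concrete `FriConfig::new` arguments simply mirror the `Default` impl above; their meaning
// is assumed from that impl and may differ in other versions of the crate.
#[cfg(test)]
#[test]
fn pcs_config_example() {
    let config = PcsConfig {
        pow_bits: 20,
        fri_config: FriConfig::new(0, 1, 3),
    };

    assert_eq!(config.pow_bits, 20);
}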
|
||||
256
Stwo_wrapper/crates/prover/src/core/pcs/prover.rs
Normal file
@ -0,0 +1,256 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use itertools::Itertools;
|
||||
use tracing::{span, Level};
|
||||
|
||||
use super::super::circle::CirclePoint;
|
||||
use super::super::fields::m31::BaseField;
|
||||
use super::super::fields::qm31::SecureField;
|
||||
use super::super::fri::{FriProof, FriProver};
|
||||
use super::super::poly::circle::CanonicCoset;
|
||||
use super::super::poly::BitReversedOrder;
|
||||
use super::super::ColumnVec;
|
||||
use super::quotients::{compute_fri_quotients, PointSample};
|
||||
use super::utils::TreeVec;
|
||||
use super::{PcsConfig, TreeSubspan};
|
||||
use crate::core::air::Trace;
|
||||
use crate::core::backend::BackendForChannel;
|
||||
use crate::core::channel::{Channel, MerkleChannel};
|
||||
use crate::core::poly::circle::{CircleEvaluation, CirclePoly};
|
||||
use crate::core::poly::twiddles::TwiddleTree;
|
||||
use crate::core::vcs::ops::MerkleHasher;
|
||||
use crate::core::vcs::prover::{MerkleDecommitment, MerkleProver};
|
||||
|
||||
/// The prover side of a FRI polynomial commitment scheme. See [super].
|
||||
pub struct CommitmentSchemeProver<'a, B: BackendForChannel<MC>, MC: MerkleChannel> {
|
||||
pub trees: TreeVec<CommitmentTreeProver<B, MC>>,
|
||||
pub config: PcsConfig,
|
||||
twiddles: &'a TwiddleTree<B>,
|
||||
}
|
||||
|
||||
impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a, B, MC> {
|
||||
pub fn new(config: PcsConfig, twiddles: &'a TwiddleTree<B>) -> Self {
|
||||
CommitmentSchemeProver {
|
||||
trees: TreeVec::default(),
|
||||
config,
|
||||
twiddles,
|
||||
}
|
||||
}
|
||||
|
||||
fn commit(&mut self, polynomials: ColumnVec<CirclePoly<B>>, channel: &mut MC::C) {
|
||||
let _span = span!(Level::INFO, "Commitment").entered();
|
||||
let tree = CommitmentTreeProver::new(
|
||||
polynomials,
|
||||
self.config.fri_config.log_blowup_factor,
|
||||
channel,
|
||||
self.twiddles,
|
||||
);
|
||||
self.trees.push(tree);
|
||||
}
|
||||
|
||||
pub fn tree_builder(&mut self) -> TreeBuilder<'_, 'a, B, MC> {
|
||||
TreeBuilder {
|
||||
tree_index: self.trees.len(),
|
||||
commitment_scheme: self,
|
||||
polys: Vec::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn roots(&self) -> TreeVec<<MC::H as MerkleHasher>::Hash> {
|
||||
self.trees.as_ref().map(|tree| tree.commitment.root())
|
||||
}
|
||||
|
||||
pub fn polynomials(&self) -> TreeVec<ColumnVec<&CirclePoly<B>>> {
|
||||
self.trees
|
||||
.as_ref()
|
||||
.map(|tree| tree.polynomials.iter().collect())
|
||||
}
|
||||
|
||||
pub fn evaluations(
|
||||
&self,
|
||||
) -> TreeVec<ColumnVec<&CircleEvaluation<B, BaseField, BitReversedOrder>>> {
|
||||
self.trees
|
||||
.as_ref()
|
||||
.map(|tree| tree.evaluations.iter().collect())
|
||||
}
|
||||
|
||||
pub fn trace(&self) -> Trace<'_, B> {
|
||||
let polys = self.polynomials();
|
||||
let evals = self.evaluations();
|
||||
Trace { polys, evals }
|
||||
}
|
||||
|
||||
pub fn prove_values(
|
||||
&self,
|
||||
sampled_points: TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>>,
|
||||
channel: &mut MC::C,
|
||||
) -> CommitmentSchemeProof<MC::H> {
|
||||
// Evaluate polynomials on open points.
|
||||
let span = span!(Level::INFO, "Evaluate columns out of domain").entered();
|
||||
let samples = self
|
||||
.polynomials()
|
||||
.zip_cols(&sampled_points)
|
||||
.map_cols(|(poly, points)| {
|
||||
points
|
||||
.iter()
|
||||
.map(|&point| PointSample {
|
||||
point,
|
||||
value: poly.eval_at_point(point),
|
||||
})
|
||||
.collect_vec()
|
||||
});
|
||||
span.exit();
|
||||
let sampled_values = samples
|
||||
.as_cols_ref()
|
||||
.map_cols(|x| x.iter().map(|o| o.value).collect());
|
||||
channel.mix_felts(&sampled_values.clone().flatten_cols());
|
||||
|
||||
// Compute oods quotients for boundary constraints on the sampled points.
|
||||
let columns = self.evaluations().flatten();
|
||||
let quotients = compute_fri_quotients(
|
||||
&columns,
|
||||
&samples.flatten(),
|
||||
channel.draw_felt(),
|
||||
self.config.fri_config.log_blowup_factor,
|
||||
);
|
||||
|
||||
// Run FRI commitment phase on the oods quotients.
|
||||
let fri_prover =
|
||||
FriProver::<B, MC>::commit(channel, self.config.fri_config, &quotients, self.twiddles);
|
||||
|
||||
// Proof of work.
|
||||
let span1 = span!(Level::INFO, "Grind").entered();
|
||||
let proof_of_work = B::grind(channel, self.config.pow_bits);
|
||||
span1.exit();
|
||||
channel.mix_nonce(proof_of_work);
|
||||
|
||||
// FRI decommitment phase.
|
||||
let (fri_proof, fri_query_domains) = fri_prover.decommit(channel);
|
||||
|
||||
// Decommit the FRI queries on the merkle trees.
|
||||
let decommitment_results = self.trees.as_ref().map(|tree| {
|
||||
let queries = fri_query_domains
|
||||
.iter()
|
||||
.map(|(&log_size, domain)| (log_size, domain.flatten()))
|
||||
.collect();
|
||||
tree.decommit(queries)
|
||||
});
|
||||
|
||||
let queried_values = decommitment_results.as_ref().map(|(v, _)| v.clone());
|
||||
let decommitments = decommitment_results.map(|(_, d)| d);
|
||||
|
||||
CommitmentSchemeProof {
|
||||
sampled_values,
|
||||
decommitments,
|
||||
queried_values,
|
||||
proof_of_work,
|
||||
fri_proof,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CommitmentSchemeProof<H: MerkleHasher> {
|
||||
pub sampled_values: TreeVec<ColumnVec<Vec<SecureField>>>,
|
||||
pub decommitments: TreeVec<MerkleDecommitment<H>>,
|
||||
pub queried_values: TreeVec<ColumnVec<Vec<BaseField>>>,
|
||||
pub proof_of_work: u64,
|
||||
pub fri_proof: FriProof<H>,
|
||||
}
|
||||
|
||||
pub struct TreeBuilder<'a, 'b, B: BackendForChannel<MC>, MC: MerkleChannel> {
|
||||
tree_index: usize,
|
||||
commitment_scheme: &'a mut CommitmentSchemeProver<'b, B, MC>,
|
||||
polys: ColumnVec<CirclePoly<B>>,
|
||||
}
|
||||
impl<'a, 'b, B: BackendForChannel<MC>, MC: MerkleChannel> TreeBuilder<'a, 'b, B, MC> {
|
||||
pub fn extend_evals(
|
||||
&mut self,
|
||||
columns: ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
|
||||
) -> TreeSubspan {
|
||||
let span = span!(Level::INFO, "Interpolation for commitment").entered();
|
||||
let col_start = self.polys.len();
|
||||
let polys = columns
|
||||
.into_iter()
|
||||
.map(|eval| eval.interpolate_with_twiddles(self.commitment_scheme.twiddles))
|
||||
.collect_vec();
|
||||
span.exit();
|
||||
self.polys.extend(polys);
|
||||
TreeSubspan {
|
||||
tree_index: self.tree_index,
|
||||
col_start,
|
||||
col_end: self.polys.len(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extend_polys(&mut self, polys: ColumnVec<CirclePoly<B>>) -> TreeSubspan {
|
||||
let col_start = self.polys.len();
|
||||
self.polys.extend(polys);
|
||||
TreeSubspan {
|
||||
tree_index: self.tree_index,
|
||||
col_start,
|
||||
col_end: self.polys.len(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn commit(self, channel: &mut MC::C) {
|
||||
let _span = span!(Level::INFO, "Commitment").entered();
|
||||
self.commitment_scheme.commit(self.polys, channel);
|
||||
}
|
||||
}
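// Editorial note (not part of the upstream file): the intended usage is to obtain a
// `TreeBuilder` from `CommitmentSchemeProver::tree_builder`, add columns with `extend_evals`
// (which interpolates them into polynomials) or `extend_polys`, and finally call `commit`,
// which hands the accumulated polynomials back to the commitment scheme so they are evaluated
// on the blowup domain, Merkle-committed, and the root mixed into the channel.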
|
||||
|
||||
/// Prover data for a single commitment tree in a commitment scheme. The commitment scheme allows
/// committing to a set of polynomials at a time. This struct corresponds to one such set.
|
||||
pub struct CommitmentTreeProver<B: BackendForChannel<MC>, MC: MerkleChannel> {
|
||||
pub polynomials: ColumnVec<CirclePoly<B>>,
|
||||
pub evaluations: ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
|
||||
pub commitment: MerkleProver<B, MC::H>,
|
||||
}
|
||||
|
||||
impl<B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentTreeProver<B, MC> {
|
||||
pub fn new(
|
||||
polynomials: ColumnVec<CirclePoly<B>>,
|
||||
log_blowup_factor: u32,
|
||||
channel: &mut MC::C,
|
||||
twiddles: &TwiddleTree<B>,
|
||||
) -> Self {
|
||||
let span = span!(Level::INFO, "Extension").entered();
|
||||
let evaluations = polynomials
|
||||
.iter()
|
||||
.map(|poly| {
|
||||
poly.evaluate_with_twiddles(
|
||||
CanonicCoset::new(poly.log_size() + log_blowup_factor).circle_domain(),
|
||||
twiddles,
|
||||
)
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
span.exit();
|
||||
|
||||
let _span = span!(Level::INFO, "Merkle").entered();
|
||||
let tree = MerkleProver::commit(evaluations.iter().map(|eval| &eval.values).collect());
|
||||
MC::mix_root(channel, tree.root());
|
||||
|
||||
CommitmentTreeProver {
|
||||
polynomials,
|
||||
evaluations,
|
||||
commitment: tree,
|
||||
}
|
||||
}
|
||||
|
||||
/// Decommits the merkle tree on the given query positions.
|
||||
/// Returns the values at the queried positions and the decommitment.
|
||||
/// The queries are given as a mapping from the log of the layer size to the queried
|
||||
/// positions on each column of that size.
|
||||
fn decommit(
|
||||
&self,
|
||||
queries: BTreeMap<u32, Vec<usize>>,
|
||||
) -> (ColumnVec<Vec<BaseField>>, MerkleDecommitment<MC::H>) {
|
||||
let eval_vec = self
|
||||
.evaluations
|
||||
.iter()
|
||||
.map(|eval| &eval.values)
|
||||
.collect_vec();
|
||||
self.commitment.decommit(queries, eval_vec)
|
||||
}
|
||||
}
|
||||
218
Stwo_wrapper/crates/prover/src/core/pcs/quotients.rs
Normal file
@ -0,0 +1,218 @@
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::BTreeMap;
|
||||
use std::iter::zip;
|
||||
|
||||
use itertools::{izip, multiunzip, Itertools};
|
||||
use tracing::{span, Level};
|
||||
|
||||
use crate::core::backend::cpu::quotients::{accumulate_row_quotients, quotient_constants};
|
||||
use crate::core::circle::CirclePoint;
|
||||
use crate::core::fields::m31::BaseField;
|
||||
use crate::core::fields::qm31::SecureField;
|
||||
use crate::core::fri::SparseCircleEvaluation;
|
||||
use crate::core::poly::circle::{
|
||||
CanonicCoset, CircleDomain, CircleEvaluation, PolyOps, SecureEvaluation,
|
||||
};
|
||||
use crate::core::poly::BitReversedOrder;
|
||||
use crate::core::prover::VerificationError;
|
||||
use crate::core::queries::SparseSubCircleDomain;
|
||||
use crate::core::utils::bit_reverse_index;
|
||||
|
||||
pub trait QuotientOps: PolyOps {
|
||||
/// Accumulates the quotients of the columns at the given domain.
|
||||
/// For a column f(x), and a point sample (p,v), the quotient is
|
||||
/// (f(x) - V0(x))/V1(x)
|
||||
/// where V0(p)=v, V0(conj(p))=conj(v), and V1 is a vanishing polynomial for p,conj(p).
|
||||
/// This ensures that if f(p)=v, then the quotient is a polynomial.
|
||||
/// The result is a linear combination of the quotients using powers of random_coeff.
|
||||
fn accumulate_quotients(
|
||||
domain: CircleDomain,
|
||||
columns: &[&CircleEvaluation<Self, BaseField, BitReversedOrder>],
|
||||
random_coeff: SecureField,
|
||||
sample_batches: &[ColumnSampleBatch],
|
||||
log_blowup_factor: u32,
|
||||
) -> SecureEvaluation<Self, BitReversedOrder>;
|
||||
}
|
||||
|
||||
/// A batch of column samplings at a point.
|
||||
pub struct ColumnSampleBatch {
|
||||
/// The point at which the columns are sampled.
|
||||
pub point: CirclePoint<SecureField>,
|
||||
/// The sampled column indices and their values at the point.
|
||||
pub columns_and_values: Vec<(usize, SecureField)>,
|
||||
}
|
||||
|
||||
impl ColumnSampleBatch {
|
||||
/// Groups column samples by sampled point.
|
||||
/// # Arguments
|
||||
/// samples: For each column, a vector of samples.
|
||||
pub fn new_vec(samples: &[&Vec<PointSample>]) -> Vec<Self> {
|
||||
// Group samples by point, and create a ColumnSampleBatch for each point.
|
||||
// This should keep a stable ordering.
|
||||
let mut grouped_samples = BTreeMap::new();
|
||||
for (column_index, samples) in samples.iter().enumerate() {
|
||||
for sample in samples.iter() {
|
||||
grouped_samples
|
||||
.entry(sample.point)
|
||||
.or_insert_with(Vec::new)
|
||||
.push((column_index, sample.value));
|
||||
}
|
||||
}
|
||||
grouped_samples
|
||||
.into_iter()
|
||||
.map(|(point, columns_and_values)| ColumnSampleBatch {
|
||||
point,
|
||||
columns_and_values,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PointSample {
|
||||
pub point: CirclePoint<SecureField>,
|
||||
pub value: SecureField,
|
||||
}
|
||||
|
||||
pub fn compute_fri_quotients<B: QuotientOps>(
|
||||
columns: &[&CircleEvaluation<B, BaseField, BitReversedOrder>],
|
||||
samples: &[Vec<PointSample>],
|
||||
random_coeff: SecureField,
|
||||
log_blowup_factor: u32,
|
||||
) -> Vec<SecureEvaluation<B, BitReversedOrder>> {
|
||||
let _span = span!(Level::INFO, "Compute FRI quotients").entered();
|
||||
zip(columns, samples)
|
||||
.sorted_by_key(|(c, _)| Reverse(c.domain.log_size()))
|
||||
.group_by(|(c, _)| c.domain.log_size())
|
||||
.into_iter()
|
||||
.map(|(log_size, tuples)| {
|
||||
let (columns, samples): (Vec<_>, Vec<_>) = tuples.unzip();
|
||||
let domain = CanonicCoset::new(log_size).circle_domain();
|
||||
// TODO: slice.
|
||||
let sample_batches = ColumnSampleBatch::new_vec(&samples);
|
||||
B::accumulate_quotients(
|
||||
domain,
|
||||
&columns,
|
||||
random_coeff,
|
||||
&sample_batches,
|
||||
log_blowup_factor,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn fri_answers(
|
||||
column_log_sizes: Vec<u32>,
|
||||
samples: &[Vec<PointSample>],
|
||||
random_coeff: SecureField,
|
||||
query_domain_per_log_size: BTreeMap<u32, SparseSubCircleDomain>,
|
||||
queried_values_per_column: &[Vec<BaseField>],
|
||||
) -> Result<Vec<SparseCircleEvaluation>, VerificationError> {
|
||||
izip!(column_log_sizes, samples, queried_values_per_column)
|
||||
.sorted_by_key(|(log_size, ..)| Reverse(*log_size))
|
||||
.group_by(|(log_size, ..)| *log_size)
|
||||
.into_iter()
|
||||
.map(|(log_size, tuples)| {
|
||||
let (_, samples, queried_values_per_column): (Vec<_>, Vec<_>, Vec<_>) =
|
||||
multiunzip(tuples);
|
||||
fri_answers_for_log_size(
|
||||
log_size,
|
||||
&samples, // Here it's the points and sampled values (in the proof)
|
||||
                random_coeff,
                &query_domain_per_log_size[&log_size],
                &queried_values_per_column, // Here it's the queried values (in the proof)
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn fri_answers_for_log_size(
|
||||
log_size: u32,
|
||||
samples: &[&Vec<PointSample>],
|
||||
random_coeff: SecureField,
|
||||
query_domain: &SparseSubCircleDomain,
|
||||
queried_values_per_column: &[&Vec<BaseField>],
|
||||
) -> Result<SparseCircleEvaluation, VerificationError> {
|
||||
let commitment_domain = CanonicCoset::new(log_size).circle_domain();
|
||||
let sample_batches = ColumnSampleBatch::new_vec(samples);
|
||||
for queried_values in queried_values_per_column {
|
||||
if queried_values.len() != query_domain.flatten().len() {
|
||||
return Err(VerificationError::InvalidStructure(
|
||||
"Insufficient number of queried values".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
let mut queried_values_per_column = queried_values_per_column
|
||||
.iter()
|
||||
.map(|q| q.iter())
|
||||
.collect_vec();
|
||||
|
||||
let mut evals = Vec::new();
|
||||
for subdomain in query_domain.iter() {
|
||||
let domain = subdomain.to_circle_domain(&commitment_domain);
|
||||
let quotient_constants = quotient_constants(&sample_batches, random_coeff, domain);
|
||||
let mut column_evals = Vec::new();
|
||||
for queried_values in queried_values_per_column.iter_mut() {
|
||||
let eval = CircleEvaluation::new(
|
||||
domain,
|
||||
queried_values.take(domain.size()).copied().collect_vec(),
|
||||
);
|
||||
column_evals.push(eval);
|
||||
}
|
||||
|
||||
let mut values = Vec::new();
|
||||
for row in 0..domain.size() {
|
||||
let domain_point = domain.at(bit_reverse_index(row, log_size));
|
||||
            let value = accumulate_row_quotients(
                &sample_batches,
                &column_evals.iter().collect_vec(),
                &quotient_constants,
|
||||
row,
|
||||
domain_point,
|
||||
);
|
||||
values.push(value);
|
||||
}
|
||||
let eval = CircleEvaluation::new(domain, values);
|
||||
evals.push(eval);
|
||||
}
|
||||
|
||||
let res = SparseCircleEvaluation::new(evals);
|
||||
if !queried_values_per_column.iter().all(|x| x.is_empty()) {
|
||||
return Err(VerificationError::InvalidStructure(
|
||||
"Too many queried values".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly};
|
||||
use crate::core::circle::SECURE_FIELD_CIRCLE_GEN;
|
||||
use crate::core::pcs::quotients::{compute_fri_quotients, PointSample};
|
||||
use crate::core::poly::circle::CanonicCoset;
|
||||
use crate::{m31, qm31};
|
||||
|
||||
#[test]
|
||||
fn test_quotients_are_low_degree() {
|
||||
const LOG_SIZE: u32 = 7;
|
||||
const LOG_BLOWUP_FACTOR: u32 = 1;
|
||||
let polynomial = CpuCirclePoly::new((0..1 << LOG_SIZE).map(|i| m31!(i)).collect());
|
||||
let eval_domain = CanonicCoset::new(LOG_SIZE + 1).circle_domain();
|
||||
let eval = polynomial.evaluate(eval_domain);
|
||||
let point = SECURE_FIELD_CIRCLE_GEN;
|
||||
let value = polynomial.eval_at_point(point);
|
||||
let coeff = qm31!(1, 2, 3, 4);
|
||||
let quot_eval = compute_fri_quotients(
|
||||
&[&eval],
|
||||
&[vec![PointSample { point, value }]],
|
||||
coeff,
|
||||
LOG_BLOWUP_FACTOR,
|
||||
)
|
||||
.pop()
|
||||
.unwrap();
|
||||
let quot_poly_base_field =
|
||||
CpuCircleEvaluation::new(eval_domain, quot_eval.values.columns[0].clone())
|
||||
.interpolate();
|
||||
assert!(quot_poly_base_field.is_in_fri_space(LOG_SIZE));
|
||||
}
|
||||
}
|
||||
158
Stwo_wrapper/crates/prover/src/core/pcs/utils.rs
Normal file
@ -0,0 +1,158 @@
|
||||
use std::collections::BTreeSet;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use itertools::zip_eq;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::TreeSubspan;
|
||||
use crate::core::ColumnVec;
|
||||
|
||||
/// A container that holds an element for each commitment tree.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TreeVec<T>(pub Vec<T>);
|
||||
|
||||
impl<T> TreeVec<T> {
|
||||
pub fn new(vec: Vec<T>) -> TreeVec<T> {
|
||||
TreeVec(vec)
|
||||
}
|
||||
pub fn map<U, F: Fn(T) -> U>(self, f: F) -> TreeVec<U> {
|
||||
TreeVec(self.0.into_iter().map(f).collect())
|
||||
}
|
||||
pub fn zip<U>(self, other: impl Into<TreeVec<U>>) -> TreeVec<(T, U)> {
|
||||
let other = other.into();
|
||||
TreeVec(self.0.into_iter().zip(other.0).collect())
|
||||
}
|
||||
pub fn zip_eq<U>(self, other: impl Into<TreeVec<U>>) -> TreeVec<(T, U)> {
|
||||
let other = other.into();
|
||||
TreeVec(zip_eq(self.0, other.0).collect())
|
||||
}
|
||||
pub fn as_ref(&self) -> TreeVec<&T> {
|
||||
TreeVec(self.iter().collect())
|
||||
}
|
||||
pub fn as_mut(&mut self) -> TreeVec<&mut T> {
|
||||
TreeVec(self.iter_mut().collect())
|
||||
}
|
||||
}
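// Editorial example (not part of the upstream file): `TreeVec` is a thin wrapper around `Vec`
// with one entry per commitment tree, so per-tree transformations are ordinary maps over the
// outer vector.
#[cfg(test)]
#[test]
fn tree_vec_map_example() {
    let trees = TreeVec::new(vec![1u32, 2, 3]);

    assert_eq!(trees.map(|x| x * 10).0, vec![10, 20, 30]);
}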
|
||||
|
||||
/// Converts `&TreeVec<T>` to `TreeVec<&T>`.
|
||||
impl<'a, T> From<&'a TreeVec<T>> for TreeVec<&'a T> {
|
||||
fn from(val: &'a TreeVec<T>) -> Self {
|
||||
val.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Deref for TreeVec<T> {
|
||||
type Target = Vec<T>;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> DerefMut for TreeVec<T> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Default for TreeVec<T> {
|
||||
fn default() -> Self {
|
||||
TreeVec(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> TreeVec<ColumnVec<T>> {
|
||||
pub fn map_cols<U, F: FnMut(T) -> U>(self, mut f: F) -> TreeVec<ColumnVec<U>> {
|
||||
TreeVec(
|
||||
self.0
|
||||
.into_iter()
|
||||
.map(|column| column.into_iter().map(&mut f).collect())
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Zips two [`TreeVec<ColumVec<T>>`] with the same structure (number of columns in each tree).
|
||||
/// The resulting [`TreeVec<ColumVec<T>>`] has the same structure, with each value being a tuple
|
||||
/// of the corresponding values from the input [`TreeVec<ColumVec<T>>`].
|
||||
pub fn zip_cols<U>(
|
||||
self,
|
||||
other: impl Into<TreeVec<ColumnVec<U>>>,
|
||||
) -> TreeVec<ColumnVec<(T, U)>> {
|
||||
let other = other.into();
|
||||
TreeVec(
|
||||
zip_eq(self.0, other.0)
|
||||
.map(|(column1, column2)| zip_eq(column1, column2).collect())
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn as_cols_ref(&self) -> TreeVec<ColumnVec<&T>> {
|
||||
TreeVec(self.iter().map(|column| column.iter().collect()).collect())
|
||||
}
|
||||
|
||||
/// Flattens the [`TreeVec<ColumVec<T>>`] into a single [`ColumnVec`] with all the columns
|
||||
/// combined.
|
||||
pub fn flatten(self) -> ColumnVec<T> {
|
||||
self.0.into_iter().flatten().collect()
|
||||
}
|
||||
|
||||
/// Appends the columns of another [`TreeVec<ColumVec<T>>`] to this one.
|
||||
pub fn append_cols(&mut self, mut other: TreeVec<ColumnVec<T>>) {
|
||||
let n_trees = self.0.len().max(other.0.len());
|
||||
self.0.resize_with(n_trees, Default::default);
|
||||
for (self_col, other_col) in self.0.iter_mut().zip(other.0.iter_mut()) {
|
||||
self_col.append(other_col);
|
||||
}
|
||||
}
|
||||
|
||||
/// Concatenates the columns of multiple [`TreeVec<ColumVec<T>>`] into a single
|
||||
/// [`TreeVec<ColumVec<T>>`].
|
||||
pub fn concat_cols(
|
||||
trees: impl Iterator<Item = TreeVec<ColumnVec<T>>>,
|
||||
) -> TreeVec<ColumnVec<T>> {
|
||||
let mut result = TreeVec::default();
|
||||
for tree in trees {
|
||||
result.append_cols(tree);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Extracts a sub-tree based on the specified locations.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If two or more locations have the same tree index.
|
||||
pub fn sub_tree(&self, locations: &[TreeSubspan]) -> TreeVec<ColumnVec<&T>> {
|
||||
let tree_indicies: BTreeSet<usize> = locations.iter().map(|l| l.tree_index).collect();
|
||||
assert_eq!(tree_indicies.len(), locations.len());
|
||||
let max_tree_index = tree_indicies.iter().max().unwrap_or(&0);
|
||||
let mut res = TreeVec(vec![Vec::new(); max_tree_index + 1]);
|
||||
|
||||
for &location in locations {
|
||||
// TODO(andrew): Throwing error here might be better instead.
|
||||
let chunk = self.get_chunk(location).unwrap();
|
||||
res[location.tree_index] = chunk;
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
fn get_chunk(&self, location: TreeSubspan) -> Option<ColumnVec<&T>> {
|
||||
let tree = self.0.get(location.tree_index)?;
|
||||
let chunk = tree.get(location.col_start..location.col_end)?;
|
||||
Some(chunk.iter().collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> From<&'a TreeVec<ColumnVec<T>>> for TreeVec<ColumnVec<&'a T>> {
|
||||
fn from(val: &'a TreeVec<ColumnVec<T>>) -> Self {
|
||||
val.as_cols_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> TreeVec<ColumnVec<Vec<T>>> {
|
||||
/// Flattens a [`TreeVec<ColumVec<T>>`] of [Vec]s into a single [Vec] with all the elements
|
||||
/// combined.
|
||||
pub fn flatten_cols(self) -> Vec<T> {
|
||||
self.0.into_iter().flatten().flatten().collect()
|
||||
}
|
||||
}
|
||||
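A minimal usage sketch of the container above (editorial, hypothetical test module at the bottom of utils.rs, not in this commit; it assumes ColumnVec is the plain Vec alias imported at the top of the file):

#[cfg(test)]
mod sketch_tests {
    use super::TreeVec;

    #[test]
    fn zip_and_flatten_sketch() {
        // Two "trees" with the same column structure: two columns, then one column.
        let log_sizes = TreeVec::new(vec![vec![3u32, 4], vec![5]]);
        let labels = TreeVec::new(vec![vec![30u32, 40], vec![50]]);

        // zip_cols pairs up corresponding columns; flatten concatenates all columns.
        let flat = log_sizes.zip_cols(labels).flatten();
        assert_eq!(flat, vec![(3, 30), (4, 40), (5, 50)]);
    }
}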
134
Stwo_wrapper/crates/prover/src/core/pcs/verifier.rs
Normal file
@ -0,0 +1,134 @@
use std::iter::zip;

use itertools::Itertools;

use super::super::circle::CirclePoint;
use super::super::fields::qm31::SecureField;
use super::super::fri::{CirclePolyDegreeBound, FriVerifier};
use super::quotients::{fri_answers, PointSample};
use super::utils::TreeVec;
use super::{CommitmentSchemeProof, PcsConfig};
use crate::core::channel::{Channel, MerkleChannel};
use crate::core::prover::VerificationError;
use crate::core::vcs::ops::MerkleHasher;
use crate::core::vcs::verifier::MerkleVerifier;
use crate::core::ColumnVec;

/// The verifier side of a FRI polynomial commitment scheme. See [super].
#[derive(Default)]
pub struct CommitmentSchemeVerifier<MC: MerkleChannel> {
    pub trees: TreeVec<MerkleVerifier<MC::H>>,
    pub config: PcsConfig,
}

impl<MC: MerkleChannel> CommitmentSchemeVerifier<MC> {
    pub fn new(config: PcsConfig) -> Self {
        Self {
            trees: TreeVec::default(),
            config,
        }
    }

    /// A [TreeVec<ColumnVec>] of the log sizes of each column in each commitment tree.
    fn column_log_sizes(&self) -> TreeVec<ColumnVec<u32>> {
        self.trees
            .as_ref()
            .map(|tree| tree.column_log_sizes.clone())
    }

    /// Reads a commitment from the prover.
    pub fn commit(
        &mut self,
        commitment: <MC::H as MerkleHasher>::Hash,
        log_sizes: &[u32],
        channel: &mut MC::C,
    ) {
        MC::mix_root(channel, commitment);
        let extended_log_sizes = log_sizes
            .iter()
            .map(|&log_size| log_size + self.config.fri_config.log_blowup_factor)
            .collect();
        let verifier = MerkleVerifier::new(commitment, extended_log_sizes);
        self.trees.push(verifier);
    }

    pub fn verify_values(
        &self,
        sampled_points: TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>>,
        proof: CommitmentSchemeProof<MC::H>,
        channel: &mut MC::C,
    ) -> Result<(), VerificationError> {
        channel.mix_felts(&proof.sampled_values.clone().flatten_cols());
        let random_coeff = channel.draw_felt();

        let bounds = self
            .column_log_sizes()
            .zip_cols(&sampled_points)
            .map_cols(|(log_size, sampled_points)| {
                vec![
                    CirclePolyDegreeBound::new(log_size - self.config.fri_config.log_blowup_factor);
                    sampled_points.len()
                ]
            })
            .flatten_cols()
            .into_iter()
            .sorted()
            .rev()
            .dedup()
            .collect_vec();

        // FRI commitment phase on OODS quotients.
        let mut fri_verifier =
            FriVerifier::<MC>::commit(channel, self.config.fri_config, proof.fri_proof, bounds)?;

        // Verify proof of work.
        channel.mix_nonce(proof.proof_of_work);
        if channel.trailing_zeros() < self.config.pow_bits {
            return Err(VerificationError::ProofOfWork);
        }

        // Get FRI query domains.
        let fri_query_domains = fri_verifier.column_query_positions(channel);

        // Verify merkle decommitments.
        self.trees
            .as_ref()
            .zip_eq(proof.decommitments)
            .zip_eq(proof.queried_values.clone())
            .map(|((tree, decommitment), queried_values)| {
                let queries = fri_query_domains
                    .iter()
                    .map(|(&log_size, domain)| (log_size, domain.flatten()))
                    .collect();
                tree.verify(queries, queried_values, decommitment)
            })
            .0
            .into_iter()
            .collect::<Result<_, _>>()?;
        println!("DONE");

        // Answer FRI queries.
        // The sample point is always the same; the value is the one provided in the proof.
        let samples = sampled_points
            .zip_cols(proof.sampled_values)
            .map_cols(|(sampled_points, sampled_values)| {
                zip(sampled_points, sampled_values)
                    .map(|(point, value)| PointSample { point, value })
                    .collect_vec()
            })
            .flatten();

        // TODO(spapini): Properly define column log size and distinguish between poly and
        // commitment.
        let fri_answers = fri_answers(
            self.column_log_sizes().flatten().into_iter().collect(),
            &samples,
            random_coeff,
            fri_query_domains,
            &proof.queried_values.flatten(),
        )?;

        fri_verifier.decommit(fri_answers)?;
        Ok(())
    }
}
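A hypothetical usage sketch for the commit/verify flow above (the helper name and argument shapes are editorial assumptions, not part of the commit; it relies only on the `commit` signature defined in this file):

/// Illustrative only: replays a prover's commitments into a fresh verifier.
fn read_commitments<MC: MerkleChannel>(
    config: PcsConfig,
    commitments: Vec<<MC::H as MerkleHasher>::Hash>,
    log_sizes_per_tree: &[Vec<u32>],
    channel: &mut MC::C,
) -> CommitmentSchemeVerifier<MC> {
    let mut verifier = CommitmentSchemeVerifier::new(config);
    // Each root is mixed into the channel and stored as a MerkleVerifier whose
    // column log sizes are extended by the FRI blowup factor (see `commit`).
    for (commitment, log_sizes) in commitments.into_iter().zip(log_sizes_per_tree) {
        verifier.commit(commitment, log_sizes, channel);
    }
    verifier
}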
77
Stwo_wrapper/crates/prover/src/core/poly/circle/canonic.rs
Normal file
@ -0,0 +1,77 @@
use super::CircleDomain;
use crate::core::circle::{CirclePoint, CirclePointIndex, Coset};
use crate::core::fields::m31::BaseField;

/// A coset of the form G_{2n} + <G_n>, where G_n is the generator of the
/// subgroup of order n. The ordering on this coset is G_2n + i * G_n.
/// These cosets can be used as a [CircleDomain], and be interpolated on.
/// Note that this changes the ordering on the coset to be like [CircleDomain],
/// which is G_2n + i * G_n/2 and then -G_2n - i * G_n/2.
/// For example, the Xs below are a canonic coset with n=8.
/// ```text
///    X O X
///  O     O
/// X       X
/// O       O
/// X       X
///  O     O
///    X O X
/// ```
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct CanonicCoset {
    pub coset: Coset,
}

impl CanonicCoset {
    pub fn new(log_size: u32) -> Self {
        assert!(log_size > 0);
        Self {
            coset: Coset::odds(log_size),
        }
    }

    /// Gets the full coset represented, G_{2n} + <G_n>.
    pub fn coset(&self) -> Coset {
        self.coset
    }

    /// Gets half of the coset (whose conjugate completes it to the whole coset), G_{2n} + <G_{n/2}>.
    pub fn half_coset(&self) -> Coset {
        Coset::half_odds(self.log_size() - 1)
    }

    /// Gets the [CircleDomain] representing the same point set (in another order).
    pub fn circle_domain(&self) -> CircleDomain {
        CircleDomain::new(self.half_coset())
    }

    /// Returns the log size of the coset.
    pub fn log_size(&self) -> u32 {
        self.coset.log_size
    }

    /// Returns the size of the coset.
    pub fn size(&self) -> usize {
        self.coset.size()
    }

    pub fn initial_index(&self) -> CirclePointIndex {
        self.coset.initial_index
    }

    pub fn step_size(&self) -> CirclePointIndex {
        self.coset.step_size
    }

    pub fn step(&self) -> CirclePoint<BaseField> {
        self.coset.step
    }

    pub fn index_at(&self, index: usize) -> CirclePointIndex {
        self.coset.index_at(index)
    }

    pub fn at(&self, i: usize) -> CirclePoint<BaseField> {
        self.coset.at(i)
    }
}
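A small sketch of how the accessors above relate (hypothetical test, editorial; it uses only the methods defined in this file and the CircleDomain API from domain.rs):

#[test]
fn canonic_coset_sketch() {
    use crate::core::poly::circle::CanonicCoset;

    // A canonic coset of log size 4 has 16 points; its circle domain covers the
    // same point set, split into two conjugate half cosets of log size 3.
    let coset = CanonicCoset::new(4);
    assert_eq!(coset.size(), 16);

    let domain = coset.circle_domain();
    assert_eq!(domain.log_size(), coset.log_size());
    assert_eq!(domain.half_coset.log_size, 3);
    assert!(domain.is_canonic());
}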
188
Stwo_wrapper/crates/prover/src/core/poly/circle/domain.rs
Normal file
@ -0,0 +1,188 @@
use std::iter::Chain;

use itertools::Itertools;

use crate::core::circle::{
    CirclePoint, CirclePointIndex, Coset, CosetIterator, M31_CIRCLE_LOG_ORDER,
};
use crate::core::fields::m31::BaseField;

pub const MAX_CIRCLE_DOMAIN_LOG_SIZE: u32 = M31_CIRCLE_LOG_ORDER - 1;

/// A valid domain for circle polynomial interpolation and evaluation.
/// Valid domains are a disjoint union of two conjugate cosets: +-C + <G_n>.
/// The ordering defined on this domain is C + iG_n, and then -C - iG_n.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct CircleDomain {
    pub half_coset: Coset,
}

impl CircleDomain {
    /// Given a coset C + <G_n>, constructs the circle domain +-C + <G_n> (i.e.,
    /// this coset and its conjugate).
    pub fn new(half_coset: Coset) -> Self {
        Self { half_coset }
    }

    pub fn iter(&self) -> CircleDomainIterator {
        self.half_coset
            .iter()
            .chain(self.half_coset.conjugate().iter())
    }

    /// Iterates over point indices.
    pub fn iter_indices(&self) -> CircleDomainIndexIterator {
        self.half_coset
            .iter_indices()
            .chain(self.half_coset.conjugate().iter_indices())
    }

    /// Returns the size of the domain.
    pub fn size(&self) -> usize {
        1 << self.log_size()
    }

    /// Returns the log size of the domain.
    pub fn log_size(&self) -> u32 {
        self.half_coset.log_size + 1
    }

    /// Returns the `i`th domain element.
    pub fn at(&self, i: usize) -> CirclePoint<BaseField> {
        self.index_at(i).to_point()
    }

    /// Returns the [CirclePointIndex] of the `i`th domain element.
    pub fn index_at(&self, i: usize) -> CirclePointIndex {
        if i < self.half_coset.size() {
            self.half_coset.index_at(i)
        } else {
            -self.half_coset.index_at(i - self.half_coset.size())
        }
    }

    pub fn find(&self, i: CirclePointIndex) -> Option<usize> {
        if let Some(d) = self.half_coset.find(i) {
            return Some(d);
        }
        if let Some(d) = self.half_coset.conjugate().find(i) {
            return Some(self.half_coset.size() + d);
        }
        None
    }

    /// Returns true if the domain is canonic.
    ///
    /// Canonic domains are domains with elements that are the entire set of points defined by
    /// `G_2n + <G_n>` where `G_n` and `G_2n` are obtained by repeatedly doubling
    /// [crate::core::circle::M31_CIRCLE_GEN].
    pub fn is_canonic(&self) -> bool {
        self.half_coset.initial_index * 4 == self.half_coset.step_size
    }

    /// Splits a circle domain into smaller [CircleDomain]s, shifted by offsets.
    pub fn split(&self, log_parts: u32) -> (CircleDomain, Vec<CirclePointIndex>) {
        assert!(log_parts <= self.half_coset.log_size);
        let subdomain = CircleDomain::new(Coset::new(
            self.half_coset.initial_index,
            self.half_coset.log_size - log_parts,
        ));
        let shifts = (0..1 << log_parts)
            .map(|i| self.half_coset.step_size * i)
            .collect_vec();
        (subdomain, shifts)
    }

    pub fn shift(&self, shift: CirclePointIndex) -> CircleDomain {
        CircleDomain::new(self.half_coset.shift(shift))
    }
}

impl IntoIterator for CircleDomain {
    type Item = CirclePoint<BaseField>;
    type IntoIter = CircleDomainIterator;

    /// Iterates over the points in the domain.
    fn into_iter(self) -> CircleDomainIterator {
        self.iter()
    }
}

/// An iterator over points in a circle domain.
///
/// Let the domain be `+-c + <G>`. The first iterated points are `c + <G>`, then `-c + <-G>`.
pub type CircleDomainIterator =
    Chain<CosetIterator<CirclePoint<BaseField>>, CosetIterator<CirclePoint<BaseField>>>;

/// Like [CircleDomainIterator] but returns corresponding [CirclePointIndex]s.
type CircleDomainIndexIterator =
    Chain<CosetIterator<CirclePointIndex>, CosetIterator<CirclePointIndex>>;

#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::CircleDomain;
    use crate::core::circle::{CirclePointIndex, Coset};
    use crate::core::poly::circle::CanonicCoset;

    #[test]
    fn test_circle_domain_iterator() {
        let domain = CircleDomain::new(Coset::new(CirclePointIndex::generator(), 2));
        for (i, point) in domain.iter().enumerate() {
            if i < 4 {
                assert_eq!(
                    point,
                    (CirclePointIndex::generator() + CirclePointIndex::subgroup_gen(2) * i)
                        .to_point()
                );
            } else {
                assert_eq!(
                    point,
                    (-(CirclePointIndex::generator() + CirclePointIndex::subgroup_gen(2) * i))
                        .to_point()
                );
            }
        }
    }

    #[test]
    fn is_canonic_invalid_domain() {
        let half_coset = Coset::new(CirclePointIndex::generator(), 4);
        let not_canonic_domain = CircleDomain::new(half_coset);

        assert!(!not_canonic_domain.is_canonic());
    }

    #[test]
    fn test_at_circle_domain() {
        let domain = CanonicCoset::new(7).circle_domain();
        let half_domain_size = domain.size() / 2;

        for i in 0..half_domain_size {
            assert_eq!(domain.index_at(i), -domain.index_at(i + half_domain_size));
            assert_eq!(domain.at(i), domain.at(i + half_domain_size).conjugate());
        }
    }

    #[test]
    fn test_domain_split() {
        let domain = CanonicCoset::new(5).circle_domain();
        let (subdomain, shifts) = domain.split(2);

        let domain_points = domain.iter().collect::<Vec<_>>();
        let points_for_each_domain = shifts
            .iter()
            .map(|&shift| (subdomain.shift(shift)).iter().collect_vec())
            .collect::<Vec<_>>();
        // Interleave the points from each subdomain.
        let extended_points = (0..(1 << 3))
            .flat_map(|point_ind| {
                (0..(1 << 2))
                    .map(|shift_ind| points_for_each_domain[shift_ind][point_ind])
                    .collect_vec()
            })
            .collect_vec();
        assert_eq!(domain_points, extended_points);
    }
}
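One relation worth spelling out: iterating the domain visits the half coset and then its conjugate, which is exactly the order that at/index_at expose. A hypothetical, editorial test in the spirit of the module above:

#[test]
fn domain_ordering_sketch() {
    use crate::core::poly::circle::CanonicCoset;

    let domain = CanonicCoset::new(4).circle_domain();
    for (i, point) in domain.iter().enumerate() {
        // iter() and at() agree on the +-C + <G_n> ordering described above.
        assert_eq!(point, domain.at(i));
    }
}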
218
Stwo_wrapper/crates/prover/src/core/poly/circle/evaluation.rs
Normal file
@ -0,0 +1,218 @@
use std::marker::PhantomData;
use std::ops::{Deref, Index};

use educe::Educe;

use super::{CanonicCoset, CircleDomain, CirclePoly, PolyOps};
use crate::core::backend::cpu::CpuCircleEvaluation;
use crate::core::backend::{Col, Column};
use crate::core::circle::{CirclePointIndex, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::fields::{ExtensionOf, FieldOps};
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::{BitReversedOrder, NaturalOrder};
use crate::core::utils::bit_reverse_index;

/// An evaluation defined on a [CircleDomain].
/// The values are ordered according to the [CircleDomain] ordering.
#[derive(Educe)]
#[educe(Clone, Debug)]
pub struct CircleEvaluation<B: FieldOps<F>, F: ExtensionOf<BaseField>, EvalOrder = NaturalOrder> {
    pub domain: CircleDomain,
    pub values: Col<B, F>,
    _eval_order: PhantomData<EvalOrder>,
}

impl<B: FieldOps<F>, F: ExtensionOf<BaseField>, EvalOrder> CircleEvaluation<B, F, EvalOrder> {
    pub fn new(domain: CircleDomain, values: Col<B, F>) -> Self {
        assert_eq!(domain.size(), values.len());
        Self {
            domain,
            values,
            _eval_order: PhantomData,
        }
    }
}

// Note: The concrete implementation of the poly operations is in the specific backend used.
// For example, the CPU backend implementation is in `src/core/backend/cpu/poly.rs`.
impl<F: ExtensionOf<BaseField>, B: FieldOps<F>> CircleEvaluation<B, F, NaturalOrder> {
    // TODO(spapini): Remove. Is this even used.
    pub fn get_at(&self, point_index: CirclePointIndex) -> F {
        self.values
            .at(self.domain.find(point_index).expect("Not in domain"))
    }

    pub fn bit_reverse(mut self) -> CircleEvaluation<B, F, BitReversedOrder> {
        B::bit_reverse_column(&mut self.values);
        CircleEvaluation::new(self.domain, self.values)
    }
}

impl<F: ExtensionOf<BaseField>> CpuCircleEvaluation<F, NaturalOrder> {
    pub fn fetch_eval_on_coset(&self, coset: Coset) -> CosetSubEvaluation<'_, F> {
        assert!(coset.log_size() <= self.domain.half_coset.log_size());
        if let Some(offset) = self.domain.half_coset.find(coset.initial_index) {
            return CosetSubEvaluation::new(
                &self.values[..self.domain.half_coset.size()],
                offset,
                coset.step_size / self.domain.half_coset.step_size,
            );
        }
        if let Some(offset) = self.domain.half_coset.conjugate().find(coset.initial_index) {
            return CosetSubEvaluation::new(
                &self.values[self.domain.half_coset.size()..],
                offset,
                (-coset.step_size) / self.domain.half_coset.step_size,
            );
        }
        panic!("Coset not found in domain");
    }
}

impl<B: PolyOps> CircleEvaluation<B, BaseField, BitReversedOrder> {
    /// Creates a [CircleEvaluation] from values ordered according to
    /// [CanonicCoset]. For example, the canonic coset might look like this:
    /// G_8, G_8 + G_4, G_8 + 2G_4, G_8 + 3G_4.
    /// The circle domain will be ordered like this:
    /// G_8, G_8 + 2G_4, -G_8, -G_8 - 2G_4.
    pub fn new_canonical_ordered(coset: CanonicCoset, values: Col<B, BaseField>) -> Self {
        B::new_canonical_ordered(coset, values)
    }

    /// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation.
    pub fn interpolate(self) -> CirclePoly<B> {
        let coset = self.domain.half_coset;
        B::interpolate(self, &B::precompute_twiddles(coset))
    }

    /// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation, using
    /// precomputed twiddles.
    pub fn interpolate_with_twiddles(self, twiddles: &TwiddleTree<B>) -> CirclePoly<B> {
        B::interpolate(self, twiddles)
    }
}

impl<B: FieldOps<F>, F: ExtensionOf<BaseField>> CircleEvaluation<B, F, BitReversedOrder> {
    pub fn bit_reverse(mut self) -> CircleEvaluation<B, F, NaturalOrder> {
        B::bit_reverse_column(&mut self.values);
        CircleEvaluation::new(self.domain, self.values)
    }

    pub fn get_at(&self, point_index: CirclePointIndex) -> F {
        self.values.at(bit_reverse_index(
            self.domain.find(point_index).expect("Not in domain"),
            self.domain.log_size(),
        ))
    }
}

impl<B: FieldOps<F>, F: ExtensionOf<BaseField>, EvalOrder> Deref
    for CircleEvaluation<B, F, EvalOrder>
{
    type Target = Col<B, F>;

    fn deref(&self) -> &Self::Target {
        &self.values
    }
}

/// A part of a [CircleEvaluation], for a specific coset that is a subset of the circle domain.
pub struct CosetSubEvaluation<'a, F: ExtensionOf<BaseField>> {
    evaluation: &'a [F],
    offset: usize,
    step: isize,
}

impl<'a, F: ExtensionOf<BaseField>> CosetSubEvaluation<'a, F> {
    fn new(evaluation: &'a [F], offset: usize, step: isize) -> Self {
        assert!(evaluation.len().is_power_of_two());
        Self {
            evaluation,
            offset,
            step,
        }
    }
}

impl<'a, F: ExtensionOf<BaseField>> Index<isize> for CosetSubEvaluation<'a, F> {
    type Output = F;

    fn index(&self, index: isize) -> &Self::Output {
        let index =
            ((self.offset as isize) + index * self.step) & ((self.evaluation.len() - 1) as isize);
        &self.evaluation[index as usize]
    }
}

impl<'a, F: ExtensionOf<BaseField>> Index<usize> for CosetSubEvaluation<'a, F> {
    type Output = F;

    fn index(&self, index: usize) -> &Self::Output {
        &self[index as isize]
    }
}

#[cfg(test)]
mod tests {
    use crate::core::backend::cpu::CpuCircleEvaluation;
    use crate::core::circle::Coset;
    use crate::core::fields::m31::BaseField;
    use crate::core::poly::circle::CanonicCoset;
    use crate::core::poly::NaturalOrder;
    use crate::m31;

    #[test]
    fn test_interpolate_non_canonic() {
        let domain = CanonicCoset::new(3).circle_domain();
        assert_eq!(domain.log_size(), 3);
        let evaluation = CpuCircleEvaluation::<_, NaturalOrder>::new(
            domain,
            (0..8).map(BaseField::from_u32_unchecked).collect(),
        )
        .bit_reverse();
        let poly = evaluation.interpolate();
        for (i, point) in domain.iter().enumerate() {
            assert_eq!(poly.eval_at_point(point.into_ef()), m31!(i as u32).into());
        }
    }

    #[test]
    fn test_interpolate_canonic() {
        let coset = CanonicCoset::new(3);
        let evaluation = CpuCircleEvaluation::new_canonical_ordered(
            coset,
            (0..8).map(BaseField::from_u32_unchecked).collect(),
        );
        let poly = evaluation.interpolate();
        for (i, point) in Coset::odds(3).iter().enumerate() {
            assert_eq!(poly.eval_at_point(point.into_ef()), m31!(i as u32).into());
        }
    }

    #[test]
    pub fn test_get_at_circle_evaluation() {
        let domain = CanonicCoset::new(7).circle_domain();
        let values = (0..domain.size()).map(|i| m31!(i as u32)).collect();
        let circle_evaluation = CpuCircleEvaluation::<_, NaturalOrder>::new(domain, values);
        let bit_reversed_circle_evaluation = circle_evaluation.clone().bit_reverse();
        for index in domain.iter_indices() {
            assert_eq!(
                circle_evaluation.get_at(index),
                bit_reversed_circle_evaluation.get_at(index)
            );
        }
    }

    #[test]
    fn test_sub_evaluation() {
        let domain = CanonicCoset::new(7).circle_domain();
        let values = (0..domain.size()).map(|i| m31!(i as u32)).collect();
        let circle_evaluation = CpuCircleEvaluation::new(domain, values);
        let coset = Coset::new(domain.index_at(17), 3);
        let sub_eval = circle_evaluation.fetch_eval_on_coset(coset);
        for i in 0..coset.size() {
            assert_eq!(sub_eval[i], circle_evaluation.get_at(coset.index_at(i)));
        }
    }
}
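The isize indexing on CosetSubEvaluation wraps around the power-of-two slice, so negative indices walk the coset backwards. A hypothetical, editorial test in the spirit of test_sub_evaluation above, reusing its setup:

#[test]
fn sub_evaluation_wraparound_sketch() {
    use crate::core::backend::cpu::CpuCircleEvaluation;
    use crate::core::circle::Coset;
    use crate::core::poly::circle::CanonicCoset;
    use crate::m31;

    let domain = CanonicCoset::new(7).circle_domain();
    let values = (0..domain.size()).map(|i| m31!(i as u32)).collect();
    let circle_evaluation = CpuCircleEvaluation::new(domain, values);
    let coset = Coset::new(domain.index_at(17), 3);
    let sub_eval = circle_evaluation.fetch_eval_on_coset(coset);
    // Index -1 wraps to the last element of the coset traversal.
    assert_eq!(
        sub_eval[-1isize],
        circle_evaluation.get_at(coset.index_at(coset.size() - 1))
    );
}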
56
Stwo_wrapper/crates/prover/src/core/poly/circle/mod.rs
Normal file
@ -0,0 +1,56 @@
mod canonic;
mod domain;
mod evaluation;
mod ops;
mod poly;
mod secure_poly;

pub use canonic::CanonicCoset;
pub use domain::{CircleDomain, MAX_CIRCLE_DOMAIN_LOG_SIZE};
pub use evaluation::{CircleEvaluation, CosetSubEvaluation};
pub use ops::PolyOps;
pub use poly::CirclePoly;
pub use secure_poly::{SecureCirclePoly, SecureEvaluation};

#[cfg(test)]
mod tests {
    use super::CanonicCoset;
    use crate::core::backend::cpu::CpuCircleEvaluation;
    use crate::core::fields::m31::BaseField;
    use crate::core::utils::bit_reverse_index;

    #[test]
    fn test_interpolate_and_eval() {
        let domain = CanonicCoset::new(3).circle_domain();
        assert_eq!(domain.log_size(), 3);
        let evaluation =
            CpuCircleEvaluation::new(domain, (0..8).map(BaseField::from_u32_unchecked).collect());
        let poly = evaluation.clone().interpolate();
        let evaluation2 = poly.evaluate(domain);
        assert_eq!(evaluation.values, evaluation2.values);
    }

    #[test]
    fn is_canonic_valid_domain() {
        let canonic_domain = CanonicCoset::new(4).circle_domain();

        assert!(canonic_domain.is_canonic());
    }

    #[test]
    pub fn test_bit_reverse_indices() {
        let log_domain_size = 7;
        let log_small_domain_size = 5;
        let domain = CanonicCoset::new(log_domain_size);
        let small_domain = CanonicCoset::new(log_small_domain_size);
        let n_folds = log_domain_size - log_small_domain_size;
        for i in 0..2usize.pow(log_domain_size) {
            let point = domain.at(bit_reverse_index(i, log_domain_size));
            let small_point = small_domain.at(bit_reverse_index(
                i / 2usize.pow(n_folds),
                log_small_domain_size,
            ));
            assert_eq!(point.repeated_double(n_folds), small_point);
        }
    }
}
48
Stwo_wrapper/crates/prover/src/core/poly/circle/ops.rs
Normal file
@ -0,0 +1,48 @@
use super::{CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly};
use crate::core::backend::Col;
use crate::core::circle::{CirclePoint, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldOps;
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::BitReversedOrder;

/// Operations on BaseField polynomials.
pub trait PolyOps: FieldOps<BaseField> + Sized {
    // TODO(spapini): Use a column instead of this type.
    /// The type for precomputed twiddles.
    type Twiddles;

    /// Creates a [CircleEvaluation] from values ordered according to [CanonicCoset].
    /// Used by the [`CircleEvaluation::new_canonical_ordered()`] function.
    fn new_canonical_ordered(
        coset: CanonicCoset,
        values: Col<Self, BaseField>,
    ) -> CircleEvaluation<Self, BaseField, BitReversedOrder>;

    /// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation.
    /// Used by the [`CircleEvaluation::interpolate()`] function.
    fn interpolate(
        eval: CircleEvaluation<Self, BaseField, BitReversedOrder>,
        itwiddles: &TwiddleTree<Self>,
    ) -> CirclePoly<Self>;

    /// Evaluates the polynomial at a single point.
    /// Used by the [`CirclePoly::eval_at_point()`] function.
    fn eval_at_point(poly: &CirclePoly<Self>, point: CirclePoint<SecureField>) -> SecureField;

    /// Extends the polynomial to a larger degree bound.
    /// Used by the [`CirclePoly::extend()`] function.
    fn extend(poly: &CirclePoly<Self>, log_size: u32) -> CirclePoly<Self>;

    /// Evaluates the polynomial at all points in the domain.
    /// Used by the [`CirclePoly::evaluate()`] function.
    fn evaluate(
        poly: &CirclePoly<Self>,
        domain: CircleDomain,
        twiddles: &TwiddleTree<Self>,
    ) -> CircleEvaluation<Self, BaseField, BitReversedOrder>;

    /// Precomputes twiddles for a given coset.
    fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self>;
}
118
Stwo_wrapper/crates/prover/src/core/poly/circle/poly.rs
Normal file
@ -0,0 +1,118 @@
use super::{CircleDomain, CircleEvaluation, PolyOps};
use crate::core::backend::{Col, Column};
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldOps;
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::BitReversedOrder;

/// A polynomial defined on a [CircleDomain].
#[derive(Clone, Debug)]
pub struct CirclePoly<B: FieldOps<BaseField>> {
    /// Coefficients of the polynomial in the FFT basis.
    /// Note: These are not the coefficients of the polynomial in the standard
    /// monomial basis. The FFT basis is a tensor product of the twiddles:
    /// y, x, pi(x), pi^2(x), ..., pi^{log_size-2}(x).
    /// pi(x) := 2x^2 - 1.
    pub coeffs: Col<B, BaseField>,
    /// The number of coefficients, stored as `log2(len(coeffs))`.
    log_size: u32,
}

impl<B: PolyOps> CirclePoly<B> {
    /// Creates a new circle polynomial.
    ///
    /// Coefficients must be in the circle IFFT algorithm's basis, stored in bit-reversed order.
    ///
    /// # Panics
    ///
    /// Panics if the number of coefficients isn't a power of two.
    pub fn new(coeffs: Col<B, BaseField>) -> Self {
        assert!(coeffs.len().is_power_of_two());
        let log_size = coeffs.len().ilog2();
        Self { log_size, coeffs }
    }

    pub fn log_size(&self) -> u32 {
        self.log_size
    }

    /// Evaluates the polynomial at a single point.
    pub fn eval_at_point(&self, point: CirclePoint<SecureField>) -> SecureField {
        B::eval_at_point(self, point)
    }

    /// Extends the polynomial to a larger degree bound.
    pub fn extend(&self, log_size: u32) -> Self {
        B::extend(self, log_size)
    }

    /// Evaluates the polynomial at all points in the domain.
    pub fn evaluate(
        &self,
        domain: CircleDomain,
    ) -> CircleEvaluation<B, BaseField, BitReversedOrder> {
        B::evaluate(self, domain, &B::precompute_twiddles(domain.half_coset))
    }

    /// Evaluates the polynomial at all points in the domain, using precomputed twiddles.
    pub fn evaluate_with_twiddles(
        &self,
        domain: CircleDomain,
        twiddles: &TwiddleTree<B>,
    ) -> CircleEvaluation<B, BaseField, BitReversedOrder> {
        B::evaluate(self, domain, twiddles)
    }
}

#[cfg(test)]
impl crate::core::backend::cpu::CpuCirclePoly {
    pub fn is_in_fft_space(&self, log_fft_size: u32) -> bool {
        use num_traits::Zero;

        let mut coeffs = self.coeffs.clone();
        while coeffs.last() == Some(&BaseField::zero()) {
            coeffs.pop();
        }

        // The highest degree monomial in an FFT-space polynomial is x^{(n/2) - 1}y.
        // It is at offset (n-1). x^{(n/2)} is at offset `n`, and is not allowed.
        let highest_degree_allowed_monomial_offset = 1 << log_fft_size;
        coeffs.len() <= highest_degree_allowed_monomial_offset
    }

    /// FRI space is the space of polynomials of total degree n/2.
    /// The highest degree monomials are x^{n/2} and x^{(n/2)-1}y.
    pub fn is_in_fri_space(&self, log_fft_size: u32) -> bool {
        use num_traits::Zero;

        let mut coeffs = self.coeffs.clone();
        while coeffs.last() == Some(&BaseField::zero()) {
            coeffs.pop();
        }

        // x^{n/2} is at offset `n`, and is the last offset allowed to be non-zero.
        let highest_degree_monomial_offset = (1 << log_fft_size) + 1;
        coeffs.len() <= highest_degree_monomial_offset
    }
}

#[cfg(test)]
mod tests {
    use crate::core::backend::cpu::CpuCirclePoly;
    use crate::core::circle::CirclePoint;
    use crate::core::fields::m31::BaseField;

    #[test]
    fn test_circle_poly_extend() {
        let poly = CpuCirclePoly::new((0..16).map(BaseField::from_u32_unchecked).collect());
        let extended = poly.clone().extend(8);
        let random_point = CirclePoint::get_point(21903);

        assert_eq!(
            poly.eval_at_point(random_point),
            extended.eval_at_point(random_point)
        );
    }
}
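The two test-only predicates above differ by exactly one allowed monomial, x^{n/2}. An editorial, hypothetical test that exercises that boundary, assuming the CPU backend column is a plain Vec as the tests above rely on:

#[test]
fn fft_vs_fri_space_sketch() {
    use crate::core::backend::cpu::CpuCirclePoly;
    use crate::core::fields::m31::BaseField;

    // 16 stored coefficients, but only the first 9 are non-zero. After trailing
    // zeros are trimmed, 9 coefficients exceed the fft-space bound of 2^3 = 8
    // while still fitting the fri-space bound of 2^3 + 1 = 9.
    let mut coeffs: Vec<BaseField> = (1..=9).map(BaseField::from_u32_unchecked).collect();
    coeffs.resize(16, BaseField::from_u32_unchecked(0));
    let poly = CpuCirclePoly::new(coeffs);
    assert!(!poly.is_in_fft_space(3));
    assert!(poly.is_in_fri_space(3));
}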
118
Stwo_wrapper/crates/prover/src/core/poly/circle/secure_poly.rs
Normal file
@ -0,0 +1,118 @@
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};

use super::{CircleDomain, CircleEvaluation, CirclePoly, PolyOps};
use crate::core::backend::CpuBackend;
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE};
use crate::core::fields::FieldOps;
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::BitReversedOrder;

pub struct SecureCirclePoly<B: FieldOps<BaseField>>(pub [CirclePoly<B>; SECURE_EXTENSION_DEGREE]);

impl<B: PolyOps> SecureCirclePoly<B> {
    pub fn eval_at_point(&self, point: CirclePoint<SecureField>) -> SecureField {
        SecureField::from_partial_evals(self.eval_columns_at_point(point))
    }

    pub fn eval_columns_at_point(
        &self,
        point: CirclePoint<SecureField>,
    ) -> [SecureField; SECURE_EXTENSION_DEGREE] {
        [
            self[0].eval_at_point(point),
            self[1].eval_at_point(point),
            self[2].eval_at_point(point),
            self[3].eval_at_point(point),
        ]
    }

    pub fn log_size(&self) -> u32 {
        self[0].log_size()
    }

    pub fn evaluate_with_twiddles(
        &self,
        domain: CircleDomain,
        twiddles: &TwiddleTree<B>,
    ) -> SecureEvaluation<B, BitReversedOrder> {
        let polys = self.0.each_ref();
        let columns = polys.map(|poly| poly.evaluate_with_twiddles(domain, twiddles).values);
        SecureEvaluation::new(domain, SecureColumnByCoords { columns })
    }
}

impl<B: FieldOps<BaseField>> Deref for SecureCirclePoly<B> {
    type Target = [CirclePoly<B>; SECURE_EXTENSION_DEGREE];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// A [`SecureField`] evaluation defined on a [CircleDomain].
///
/// The evaluation is stored as a column major array of [`SECURE_EXTENSION_DEGREE`] many base field
/// evaluations. The evaluations are ordered according to the [CircleDomain] ordering.
#[derive(Clone)]
pub struct SecureEvaluation<B: FieldOps<BaseField>, EvalOrder> {
    pub domain: CircleDomain,
    pub values: SecureColumnByCoords<B>,
    _eval_order: PhantomData<EvalOrder>,
}

impl<B: FieldOps<BaseField>, EvalOrder> SecureEvaluation<B, EvalOrder> {
    pub fn new(domain: CircleDomain, values: SecureColumnByCoords<B>) -> Self {
        assert_eq!(domain.size(), values.len());
        Self {
            domain,
            values,
            _eval_order: PhantomData,
        }
    }

    pub fn into_coordinate_evals(
        self,
    ) -> [CircleEvaluation<B, BaseField, EvalOrder>; SECURE_EXTENSION_DEGREE] {
        let Self { domain, values, .. } = self;
        values.columns.map(|c| CircleEvaluation::new(domain, c))
    }
}

impl<B: FieldOps<BaseField>, EvalOrder> Deref for SecureEvaluation<B, EvalOrder> {
    type Target = SecureColumnByCoords<B>;

    fn deref(&self) -> &Self::Target {
        &self.values
    }
}

impl<B: FieldOps<BaseField>, EvalOrder> DerefMut for SecureEvaluation<B, EvalOrder> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.values
    }
}

impl<B: PolyOps> SecureEvaluation<B, BitReversedOrder> {
    /// Computes a minimal [`SecureCirclePoly`] that evaluates to the same values as this
    /// evaluation, using precomputed twiddles.
    pub fn interpolate_with_twiddles(self, twiddles: &TwiddleTree<B>) -> SecureCirclePoly<B> {
        let domain = self.domain;
        let cols = self.values.columns;
        SecureCirclePoly(cols.map(|c| {
            CircleEvaluation::<B, BaseField, BitReversedOrder>::new(domain, c)
                .interpolate_with_twiddles(twiddles)
        }))
    }
}

impl<EvalOrder> From<CircleEvaluation<CpuBackend, SecureField, EvalOrder>>
    for SecureEvaluation<CpuBackend, EvalOrder>
{
    fn from(evaluation: CircleEvaluation<CpuBackend, SecureField, EvalOrder>) -> Self {
        Self::new(evaluation.domain, evaluation.values.into_iter().collect())
    }
}
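A last sketch tying SecureCirclePoly to its four coordinate polynomials (editorial, hypothetical test placed in this file; it assumes the CPU backend column is a plain Vec and only restates the relation implemented by eval_at_point above):

#[test]
fn secure_circle_poly_coordinates_sketch() {
    use crate::core::backend::cpu::CpuCirclePoly;
    use crate::core::circle::SECURE_FIELD_CIRCLE_GEN;
    use crate::m31;

    // Four base-field coordinate polynomials make up one SecureCirclePoly.
    let coord = |offset: u32| CpuCirclePoly::new((0..8).map(|i| m31!(offset + i)).collect());
    let secure_poly = SecureCirclePoly([coord(0), coord(1), coord(2), coord(3)]);

    let point = SECURE_FIELD_CIRCLE_GEN;
    // eval_at_point recombines the per-coordinate evaluations.
    assert_eq!(
        secure_poly.eval_at_point(point),
        SecureField::from_partial_evals(secure_poly.eval_columns_at_point(point))
    );
}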
Some files were not shown because too many files have changed in this diff.