Merge pull request #56 from logos-co/origin/Circom_PoL

Merge the two circom branches
This commit is contained in:
thomaslavaur 2025-04-02 08:45:38 +02:00 committed by GitHub
commit 6299aaa843
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
200 changed files with 30504 additions and 5416 deletions

22
Stwo_wrapper/Cargo.toml Normal file
View File

@ -0,0 +1,22 @@
[workspace]
members = ["crates/prover"]
resolver = "2"
[workspace.package]
version = "0.1.1"
edition = "2021"
[workspace.dependencies]
blake2 = "0.10.6"
blake3 = "1.5.0"
educe = "0.5.0"
hex = "0.4.3"
itertools = "0.12.0"
num-traits = "0.2.17"
thiserror = "1.0.56"
bytemuck = "1.14.3"
tracing = "0.1.40"
[profile.bench]
codegen-units = 1
lto = true

201
Stwo_wrapper/LICENSE Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2024 StarkWare Industries Ltd.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

62
Stwo_wrapper/README.md Normal file
View File

@ -0,0 +1,62 @@
<div align="center">
![STWO](resources/img/logo.png)
<a href="https://github.com/starkware-libs/stwo/actions/workflows/ci.yaml"><img alt="GitHub Workflow Status (with event)" src="https://img.shields.io/github/actions/workflow/status/starkware-libs/stwo/ci.yaml?style=for-the-badge" height=30></a>
<a href="https://codecov.io/gh/starkware-libs/stwo" >
<img src="https://img.shields.io/codecov/c/github/starkware-libs/stwo?style=for-the-badge&logo=codecov" height=30/>
</a>
<a href="https://github.com/starkware-libs/stwo/blob/main/LICENSE"><img src="https://img.shields.io/github/license/starkware-libs/stwo.svg?style=for-the-badge" alt="Project license" height="30"></a>
<a href="https://starkware.co/"><img src="https://img.shields.io/badge/By StarkWare-29296E.svg?&style=for-the-badge&logo=" alt="StarkWare" height="30"></a>
</div>
<div align="center">
<h3>
<a href="https://eprint.iacr.org/2024/278">
Paper
</a>
<span> | </span>
<a href="https://github.com/starkware-libs/stwo">
Documentation
</a>
<span> | </span>
<a href="https://starkware-libs.github.io/stwo/dev/bench/index.html">
Benchmarks
</a>
</h3>
</div>
# Stwo
## 🌟 About
Stwo is a next generation implementation of a [CSTARK](https://eprint.iacr.org/2024/278) prover and verifier, written in Rust 🦀.
> **Stwo is a work in progress.**
>
> It is not recommended to use it in a production setting yet.
## 🚀 Key Features
- **Circle STARKs:** Based on the latest cryptographic research and innovations in the ZK field.
- **High performance:** Stwo is designed to be extremely fast and efficient.
- **Flexible:** Adaptable for various validity proof applications.
## 📊 Benchmarks
Run `poseidon_benchmark.sh` to run a single-threaded poseidon2 hash proof benchmark.
Further benchmarks can be run using `cargo bench`.
Visual representation of benchmarks can be found [here](https://starkware-libs.github.io/stwo/dev/bench/index.html).
## 📜 License
This project is licensed under the **Apache 2.0 license**.
See [LICENSE](LICENSE) for more information.
<!-- markdownlint-restore -->
<!-- prettier-ignore-end -->
<!-- ALL-CONTRIBUTORS-LIST:END -->

0
Stwo_wrapper/WORKSPACE Normal file
View File

View File

@ -0,0 +1,110 @@
[package]
name = "stwo-prover"
version.workspace = true
edition.workspace = true
[features]
parallel = ["rayon"]
slow-tests = []
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
blake2.workspace = true
blake3.workspace = true
bytemuck = { workspace = true, features = ["derive", "extern_crate_alloc"] }
cfg-if = "1.0.0"
downcast-rs = "1.2"
educe.workspace = true
hex.workspace = true
itertools.workspace = true
num-traits.workspace = true
rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
starknet-crypto = "0.6.2"
starknet-ff = "0.3.7"
ark-bls12-381 = "0.4.0"
ark-ff = "0.4.0"
thiserror.workspace = true
tracing.workspace = true
rayon = { version = "1.10.0", optional = true }
serde = { version = "1.0", features = ["derive"] }
crypto-bigint = "0.5.5"
ark-serialize = "0.4.0"
serde_json = "1.0.116"
[dev-dependencies]
aligned = "0.4.2"
test-log = { version = "0.2.15", features = ["trace"] }
tracing-subscriber = "0.3.18"
[target.'cfg(all(target_family = "wasm", not(target_os = "wasi")))'.dev-dependencies]
wasm-bindgen-test = "0.3.43"
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies.criterion]
features = ["html_reports"]
version = "0.5.1"
# Default features cause compile error:
# "Rayon cannot be used when targeting wasi32. Try disabling default features."
[target.'cfg(target_arch = "wasm32")'.dev-dependencies.criterion]
default-features = false
features = ["html_reports"]
version = "0.5.1"
[lib]
bench = false
crate-type = ["cdylib", "lib"]
[lints.rust]
warnings = "deny"
future-incompatible = "deny"
nonstandard-style = "deny"
rust-2018-idioms = "deny"
unused = "deny"
[[bench]]
harness = false
name = "bit_rev"
[[bench]]
harness = false
name = "eval_at_point"
[[bench]]
harness = false
name = "fft"
[[bench]]
harness = false
name = "field"
[[bench]]
harness = false
name = "fri"
[[bench]]
harness = false
name = "lookups"
[[bench]]
harness = false
name = "matrix"
[[bench]]
harness = false
name = "merkle"
[[bench]]
harness = false
name = "poseidon"
[[bench]]
harness = false
name = "prefix_sum"
[[bench]]
harness = false
name = "quotients"
[[bench]]
harness = false
name = "pcs"

View File

@ -0,0 +1,2 @@
dev benchmark results can be seen at
https://starkware-libs.github.io/stwo/dev/bench/index.html

View File

@ -0,0 +1,39 @@
#![feature(iter_array_chunks)]
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use itertools::Itertools;
use stwo_prover::core::fields::m31::BaseField;
pub fn cpu_bit_rev(c: &mut Criterion) {
use stwo_prover::core::utils::bit_reverse;
// TODO(andrew): Consider using same size for all.
const SIZE: usize = 1 << 24;
let data = (0..SIZE).map(BaseField::from).collect_vec();
c.bench_function("cpu bit_rev 24bit", |b| {
b.iter_batched(
|| data.clone(),
|mut data| bit_reverse(&mut data),
BatchSize::LargeInput,
);
});
}
pub fn simd_bit_rev(c: &mut Criterion) {
use stwo_prover::core::backend::simd::bit_reverse::bit_reverse_m31;
use stwo_prover::core::backend::simd::column::BaseColumn;
const SIZE: usize = 1 << 26;
let data = (0..SIZE).map(BaseField::from).collect::<BaseColumn>();
c.bench_function("simd bit_rev 26bit", |b| {
b.iter_batched(
|| data.data.clone(),
|mut data| bit_reverse_m31(&mut data),
BatchSize::LargeInput,
);
});
}
criterion_group!(
name = bit_rev;
config = Criterion::default().sample_size(10);
targets = simd_bit_rev, cpu_bit_rev);
criterion_main!(bit_rev);

View File

@ -0,0 +1,35 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use stwo_prover::core::backend::cpu::CpuBackend;
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::circle::CirclePoint;
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::poly::circle::{CirclePoly, PolyOps};
const LOG_SIZE: u32 = 20;
fn bench_eval_at_secure_point<B: PolyOps>(c: &mut Criterion, id: &str) {
let poly = CirclePoly::new((0..1 << LOG_SIZE).map(BaseField::from).collect());
let mut rng = SmallRng::seed_from_u64(0);
let x = rng.gen();
let y = rng.gen();
let point = CirclePoint { x, y };
c.bench_function(
&format!("{id} eval_at_secure_field_point 2^{LOG_SIZE}"),
|b| {
b.iter(|| B::eval_at_point(black_box(&poly), black_box(point)));
},
);
}
fn eval_at_secure_point_benches(c: &mut Criterion) {
bench_eval_at_secure_point::<SimdBackend>(c, "simd");
bench_eval_at_secure_point::<CpuBackend>(c, "cpu");
}
criterion_group!(
name = benches;
config = Criterion::default().sample_size(10);
targets = eval_at_secure_point_benches);
criterion_main!(benches);

View File

@ -0,0 +1,131 @@
#![feature(iter_array_chunks)]
use std::hint::black_box;
use std::mem::{size_of_val, transmute};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput};
use itertools::Itertools;
use stwo_prover::core::backend::simd::column::BaseColumn;
use stwo_prover::core::backend::simd::fft::ifft::{
get_itwiddle_dbls, ifft, ifft3_loop, ifft_vecwise_loop,
};
use stwo_prover::core::backend::simd::fft::rfft::{fft, get_twiddle_dbls};
use stwo_prover::core::backend::simd::fft::transpose_vecs;
use stwo_prover::core::backend::simd::m31::PackedBaseField;
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::poly::circle::CanonicCoset;
pub fn simd_ifft(c: &mut Criterion) {
let mut group = c.benchmark_group("iffts");
for log_size in 16..=28 {
let domain = CanonicCoset::new(log_size).circle_domain();
let twiddle_dbls = get_itwiddle_dbls(domain.half_coset);
let twiddle_dbls_refs = twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec();
let values: BaseColumn = (0..domain.size()).map(BaseField::from).collect();
group.throughput(Throughput::Bytes(size_of_val(&*values.data) as u64));
group.bench_function(BenchmarkId::new("simd ifft", log_size), |b| {
b.iter_batched(
|| values.clone().data,
|mut data| unsafe {
ifft(
transmute(data.as_mut_ptr()),
black_box(&twiddle_dbls_refs),
black_box(log_size as usize),
);
},
BatchSize::LargeInput,
)
});
}
}
pub fn simd_ifft_parts(c: &mut Criterion) {
const LOG_SIZE: u32 = 14;
let domain = CanonicCoset::new(LOG_SIZE).circle_domain();
let twiddle_dbls = get_itwiddle_dbls(domain.half_coset);
let twiddle_dbls_refs = twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec();
let values: BaseColumn = (0..domain.size()).map(BaseField::from).collect();
let mut group = c.benchmark_group("ifft parts");
// Note: These benchmarks run only on 2^LOG_SIZE elements because of their parameters.
// Increasing the figure above won't change the runtime of these benchmarks.
group.throughput(Throughput::Bytes(4 << LOG_SIZE));
group.bench_function(format!("simd ifft_vecwise_loop 2^{LOG_SIZE}"), |b| {
b.iter_batched(
|| values.clone().data,
|mut values| unsafe {
ifft_vecwise_loop(
transmute(values.as_mut_ptr()),
black_box(&twiddle_dbls_refs),
black_box(9),
black_box(0),
)
},
BatchSize::LargeInput,
);
});
group.bench_function(format!("simd ifft3_loop 2^{LOG_SIZE}"), |b| {
b.iter_batched(
|| values.clone().data,
|mut values| unsafe {
ifft3_loop(
transmute(values.as_mut_ptr()),
black_box(&twiddle_dbls_refs[3..]),
black_box(7),
black_box(4),
black_box(0),
)
},
BatchSize::LargeInput,
);
});
const TRANSPOSE_LOG_SIZE: u32 = 20;
let transpose_values: BaseColumn = (0..1 << TRANSPOSE_LOG_SIZE).map(BaseField::from).collect();
group.throughput(Throughput::Bytes(4 << TRANSPOSE_LOG_SIZE));
group.bench_function(format!("simd transpose_vecs 2^{TRANSPOSE_LOG_SIZE}"), |b| {
b.iter_batched(
|| transpose_values.clone().data,
|mut values| unsafe {
transpose_vecs(
transmute(values.as_mut_ptr()),
black_box(TRANSPOSE_LOG_SIZE as usize - 4),
)
},
BatchSize::LargeInput,
);
});
}
pub fn simd_rfft(c: &mut Criterion) {
const LOG_SIZE: u32 = 20;
let domain = CanonicCoset::new(LOG_SIZE).circle_domain();
let twiddle_dbls = get_twiddle_dbls(domain.half_coset);
let twiddle_dbls_refs = twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec();
let values: BaseColumn = (0..domain.size()).map(BaseField::from).collect();
c.bench_function("simd rfft 20bit", |b| {
b.iter_with_large_drop(|| unsafe {
let mut target = Vec::<PackedBaseField>::with_capacity(values.data.len());
#[allow(clippy::uninit_vec)]
target.set_len(values.data.len());
fft(
black_box(transmute(values.data.as_ptr())),
transmute(target.as_mut_ptr()),
black_box(&twiddle_dbls_refs),
black_box(LOG_SIZE as usize),
)
});
});
}
criterion_group!(
name = benches;
config = Criterion::default().sample_size(10);
targets = simd_ifft, simd_ifft_parts, simd_rfft);
criterion_main!(benches);

View File

@ -0,0 +1,150 @@
use criterion::{criterion_group, criterion_main, Criterion};
use num_traits::One;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use stwo_prover::core::backend::simd::m31::{PackedBaseField, N_LANES};
use stwo_prover::core::fields::cm31::CM31;
use stwo_prover::core::fields::m31::{BaseField, M31};
use stwo_prover::core::fields::qm31::SecureField;
pub const N_ELEMENTS: usize = 1 << 16;
pub const N_STATE_ELEMENTS: usize = 8;
pub fn m31_operations_bench(c: &mut Criterion) {
let mut rng = SmallRng::seed_from_u64(0);
let elements: Vec<M31> = (0..N_ELEMENTS).map(|_| rng.gen()).collect();
let mut state: [M31; N_STATE_ELEMENTS] = rng.gen();
c.bench_function("M31 mul", |b| {
b.iter(|| {
for elem in &elements {
for _ in 0..128 {
for state_elem in &mut state {
*state_elem *= *elem;
}
}
}
})
});
c.bench_function("M31 add", |b| {
b.iter(|| {
for elem in &elements {
for _ in 0..128 {
for state_elem in &mut state {
*state_elem += *elem;
}
}
}
})
});
}
pub fn cm31_operations_bench(c: &mut Criterion) {
let mut rng = SmallRng::seed_from_u64(0);
let elements: Vec<CM31> = (0..N_ELEMENTS).map(|_| rng.gen()).collect();
let mut state: [CM31; N_STATE_ELEMENTS] = rng.gen();
c.bench_function("CM31 mul", |b| {
b.iter(|| {
for elem in &elements {
for _ in 0..128 {
for state_elem in &mut state {
*state_elem *= *elem;
}
}
}
})
});
c.bench_function("CM31 add", |b| {
b.iter(|| {
for elem in &elements {
for _ in 0..128 {
for state_elem in &mut state {
*state_elem += *elem;
}
}
}
})
});
}
pub fn qm31_operations_bench(c: &mut Criterion) {
let mut rng = SmallRng::seed_from_u64(0);
let elements: Vec<SecureField> = (0..N_ELEMENTS).map(|_| rng.gen()).collect();
let mut state: [SecureField; N_STATE_ELEMENTS] = rng.gen();
c.bench_function("SecureField mul", |b| {
b.iter(|| {
for elem in &elements {
for _ in 0..128 {
for state_elem in &mut state {
*state_elem *= *elem;
}
}
}
})
});
c.bench_function("SecureField add", |b| {
b.iter(|| {
for elem in &elements {
for _ in 0..128 {
for state_elem in &mut state {
*state_elem += *elem;
}
}
}
})
});
}
pub fn simd_m31_operations_bench(c: &mut Criterion) {
let mut rng = SmallRng::seed_from_u64(0);
let elements: Vec<PackedBaseField> = (0..N_ELEMENTS / N_LANES).map(|_| rng.gen()).collect();
let mut states = vec![PackedBaseField::broadcast(BaseField::one()); N_STATE_ELEMENTS];
c.bench_function("mul_simd", |b| {
b.iter(|| {
for elem in elements.iter() {
for _ in 0..128 {
for state in states.iter_mut() {
*state *= *elem;
}
}
}
})
});
c.bench_function("add_simd", |b| {
b.iter(|| {
for elem in elements.iter() {
for _ in 0..128 {
for state in states.iter_mut() {
*state += *elem;
}
}
}
})
});
c.bench_function("sub_simd", |b| {
b.iter(|| {
for elem in elements.iter() {
for _ in 0..128 {
for state in states.iter_mut() {
*state -= *elem;
}
}
}
})
});
}
criterion_group!(
name = benches;
config = Criterion::default().sample_size(10);
targets = m31_operations_bench, cm31_operations_bench, qm31_operations_bench,
simd_m31_operations_bench);
criterion_main!(benches);

View File

@ -0,0 +1,35 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::secure_column::SecureColumnByCoords;
use stwo_prover::core::fri::FriOps;
use stwo_prover::core::poly::circle::{CanonicCoset, PolyOps};
use stwo_prover::core::poly::line::{LineDomain, LineEvaluation};
fn folding_benchmark(c: &mut Criterion) {
const LOG_SIZE: u32 = 12;
let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset());
let evals = LineEvaluation::new(
domain,
SecureColumnByCoords {
columns: std::array::from_fn(|i| {
vec![BaseField::from_u32_unchecked(i as u32); 1 << LOG_SIZE]
}),
},
);
let alpha = SecureField::from_u32_unchecked(2213980, 2213981, 2213982, 2213983);
let twiddles = CpuBackend::precompute_twiddles(domain.coset());
c.bench_function("fold_line", |b| {
b.iter(|| {
black_box(CpuBackend::fold_line(
black_box(&evals),
black_box(alpha),
&twiddles,
));
})
});
}
criterion_group!(benches, folding_benchmark);
criterion_main!(benches);

View File

@ -0,0 +1,104 @@
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use rand::distributions::{Distribution, Standard};
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::channel::Blake2sChannel;
use stwo_prover::core::fields::Field;
use stwo_prover::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer};
use stwo_prover::core::lookups::mle::{Mle, MleOps};
const LOG_N_ROWS: u32 = 16;
fn bench_gkr_grand_product<B: GkrOps>(c: &mut Criterion, id: &str) {
let mut rng = SmallRng::seed_from_u64(0);
let layer = Layer::<B>::GrandProduct(gen_random_mle(&mut rng, LOG_N_ROWS));
c.bench_function(&format!("{id} grand product lookup 2^{LOG_N_ROWS}"), |b| {
b.iter_batched(
|| layer.clone(),
|layer| prove_batch(&mut Blake2sChannel::default(), vec![layer]),
BatchSize::LargeInput,
)
});
c.bench_function(
&format!("{id} grand product lookup batch 4x 2^{LOG_N_ROWS}"),
|b| {
b.iter_batched(
|| vec![layer.clone(), layer.clone(), layer.clone(), layer.clone()],
|layers| prove_batch(&mut Blake2sChannel::default(), layers),
BatchSize::LargeInput,
)
},
);
}
fn bench_gkr_logup_generic<B: GkrOps>(c: &mut Criterion, id: &str) {
let mut rng = SmallRng::seed_from_u64(0);
let generic_layer = Layer::<B>::LogUpGeneric {
numerators: gen_random_mle(&mut rng, LOG_N_ROWS),
denominators: gen_random_mle(&mut rng, LOG_N_ROWS),
};
c.bench_function(&format!("{id} generic logup lookup 2^{LOG_N_ROWS}"), |b| {
b.iter_batched(
|| generic_layer.clone(),
|layer| prove_batch(&mut Blake2sChannel::default(), vec![layer]),
BatchSize::LargeInput,
)
});
}
fn bench_gkr_logup_multiplicities<B: GkrOps>(c: &mut Criterion, id: &str) {
let mut rng = SmallRng::seed_from_u64(0);
let multiplicities_layer = Layer::<B>::LogUpMultiplicities {
numerators: gen_random_mle(&mut rng, LOG_N_ROWS),
denominators: gen_random_mle(&mut rng, LOG_N_ROWS),
};
c.bench_function(
&format!("{id} multiplicities logup lookup 2^{LOG_N_ROWS}"),
|b| {
b.iter_batched(
|| multiplicities_layer.clone(),
|layer| prove_batch(&mut Blake2sChannel::default(), vec![layer]),
BatchSize::LargeInput,
)
},
);
}
fn bench_gkr_logup_singles<B: GkrOps>(c: &mut Criterion, id: &str) {
let mut rng = SmallRng::seed_from_u64(0);
let singles_layer = Layer::<B>::LogUpSingles {
denominators: gen_random_mle(&mut rng, LOG_N_ROWS),
};
c.bench_function(&format!("{id} singles logup lookup 2^{LOG_N_ROWS}"), |b| {
b.iter_batched(
|| singles_layer.clone(),
|layer| prove_batch(&mut Blake2sChannel::default(), vec![layer]),
BatchSize::LargeInput,
)
});
}
/// Generates a random multilinear polynomial.
fn gen_random_mle<B: MleOps<F>, F: Field>(rng: &mut impl Rng, n_variables: u32) -> Mle<B, F>
where
Standard: Distribution<F>,
{
Mle::new((0..1 << n_variables).map(|_| rng.gen()).collect())
}
fn gkr_lookup_benches(c: &mut Criterion) {
bench_gkr_grand_product::<SimdBackend>(c, "simd");
bench_gkr_logup_generic::<SimdBackend>(c, "simd");
bench_gkr_logup_multiplicities::<SimdBackend>(c, "simd");
bench_gkr_logup_singles::<SimdBackend>(c, "simd");
bench_gkr_grand_product::<CpuBackend>(c, "cpu");
bench_gkr_logup_generic::<CpuBackend>(c, "cpu");
bench_gkr_logup_multiplicities::<CpuBackend>(c, "cpu");
bench_gkr_logup_singles::<CpuBackend>(c, "cpu");
}
criterion_group!(benches, gkr_lookup_benches);
criterion_main!(benches);

View File

@ -0,0 +1,63 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use stwo_prover::core::fields::m31::{M31, P};
use stwo_prover::core::fields::qm31::QM31;
use stwo_prover::math::matrix::{RowMajorMatrix, SquareMatrix};
const MATRIX_SIZE: usize = 24;
const QM31_MATRIX_SIZE: usize = 6;
// TODO(ShaharS): Share code with other benchmarks.
fn row_major_matrix_multiplication_bench(c: &mut Criterion) {
let mut rng = SmallRng::seed_from_u64(0);
let matrix_m31 = RowMajorMatrix::<M31, MATRIX_SIZE>::new(
(0..MATRIX_SIZE.pow(2))
.map(|_| rng.gen())
.collect::<Vec<M31>>(),
);
let matrix_qm31 = RowMajorMatrix::<QM31, QM31_MATRIX_SIZE>::new(
(0..QM31_MATRIX_SIZE.pow(2))
.map(|_| rng.gen())
.collect::<Vec<QM31>>(),
);
// Create vector M31.
let vec: [M31; MATRIX_SIZE] = rng.gen();
// Create vector QM31.
let vec_qm31: [QM31; QM31_MATRIX_SIZE] = [(); QM31_MATRIX_SIZE].map(|_| {
QM31::from_u32_unchecked(
rng.gen::<u32>() % P,
rng.gen::<u32>() % P,
rng.gen::<u32>() % P,
rng.gen::<u32>() % P,
)
});
// bench matrix multiplication.
c.bench_function(
&format!("RowMajorMatrix M31 {size}x{size} mul", size = MATRIX_SIZE),
|b| {
b.iter(|| {
black_box(matrix_m31.mul(vec));
})
},
);
c.bench_function(
&format!(
"QM31 RowMajorMatrix {size}x{size} mul",
size = QM31_MATRIX_SIZE
),
|b| {
b.iter(|| {
black_box(matrix_qm31.mul(vec_qm31));
})
},
);
}
criterion_group!(benches, row_major_matrix_multiplication_bench);
criterion_main!(benches);

View File

@ -0,0 +1,38 @@
#![feature(iter_array_chunks)]
use criterion::{criterion_group, criterion_main, Criterion, Throughput};
use itertools::Itertools;
use num_traits::Zero;
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::backend::{Col, CpuBackend};
use stwo_prover::core::fields::m31::{BaseField, N_BYTES_FELT};
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use stwo_prover::core::vcs::ops::MerkleOps;
const LOG_N_ROWS: u32 = 16;
const LOG_N_COLS: u32 = 8;
fn bench_blake2s_merkle<B: MerkleOps<Blake2sMerkleHasher>>(c: &mut Criterion, id: &str) {
let col: Col<B, BaseField> = (0..1 << LOG_N_ROWS).map(|_| BaseField::zero()).collect();
let cols = (0..1 << LOG_N_COLS).map(|_| col.clone()).collect_vec();
let col_refs = cols.iter().collect_vec();
let mut group = c.benchmark_group("merkle throughput");
let n_elements = 1 << (LOG_N_COLS + LOG_N_ROWS);
group.throughput(Throughput::Elements(n_elements));
group.throughput(Throughput::Bytes(N_BYTES_FELT as u64 * n_elements));
group.bench_function(&format!("{id} merkle"), |b| {
b.iter_with_large_drop(|| B::commit_on_layer(LOG_N_ROWS, None, &col_refs))
});
}
fn blake2s_merkle_benches(c: &mut Criterion) {
bench_blake2s_merkle::<SimdBackend>(c, "simd");
bench_blake2s_merkle::<CpuBackend>(c, "cpu");
}
criterion_group!(
name = benches;
config = Criterion::default().sample_size(10);
targets = blake2s_merkle_benches);
criterion_main!(benches);

View File

@ -0,0 +1,81 @@
use std::iter;
use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion};
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::backend::{BackendForChannel, CpuBackend};
use stwo_prover::core::channel::Blake2sChannel;
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::pcs::CommitmentTreeProver;
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
use stwo_prover::core::poly::twiddles::TwiddleTree;
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;
const LOG_COSET_SIZE: u32 = 20;
const LOG_BLOWUP_FACTOR: u32 = 1;
const N_POLYS: usize = 16;
fn benched_fn<B: BackendForChannel<Blake2sMerkleChannel>>(
evals: Vec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
channel: &mut Blake2sChannel,
twiddles: &TwiddleTree<B>,
) {
let polys = evals
.into_iter()
.map(|eval| eval.interpolate_with_twiddles(twiddles))
.collect();
CommitmentTreeProver::<B, Blake2sMerkleChannel>::new(
polys,
LOG_BLOWUP_FACTOR,
channel,
twiddles,
);
}
fn bench_pcs<B: BackendForChannel<Blake2sMerkleChannel>>(c: &mut Criterion, id: &str) {
let small_domain = CanonicCoset::new(LOG_COSET_SIZE);
let big_domain = CanonicCoset::new(LOG_COSET_SIZE + LOG_BLOWUP_FACTOR);
let twiddles = B::precompute_twiddles(big_domain.half_coset());
let mut channel = Blake2sChannel::default();
let mut rng = SmallRng::seed_from_u64(0);
let evals: Vec<CircleEvaluation<B, BaseField, BitReversedOrder>> = iter::repeat_with(|| {
CircleEvaluation::new(
small_domain.circle_domain(),
(0..1 << LOG_COSET_SIZE).map(|_| rng.gen()).collect(),
)
})
.take(N_POLYS)
.collect();
c.bench_function(
&format!("{id} polynomial commitment 2^{LOG_COSET_SIZE}"),
|b| {
b.iter_batched(
|| evals.clone(),
|evals| {
benched_fn::<B>(
black_box(evals),
black_box(&mut channel),
black_box(&twiddles),
)
},
BatchSize::LargeInput,
);
},
);
}
fn pcs_benches(c: &mut Criterion) {
bench_pcs::<SimdBackend>(c, "simd");
bench_pcs::<CpuBackend>(c, "cpu");
}
criterion_group!(
name = benches;
config = Criterion::default().sample_size(10);
targets = pcs_benches);
criterion_main!(benches);

View File

@ -0,0 +1,18 @@
use criterion::{criterion_group, criterion_main, Criterion, Throughput};
use stwo_prover::core::pcs::PcsConfig;
use stwo_prover::examples::poseidon::prove_poseidon;
pub fn simd_poseidon(c: &mut Criterion) {
const LOG_N_INSTANCES: u32 = 18;
let mut group = c.benchmark_group("poseidon2");
group.throughput(Throughput::Elements(1u64 << LOG_N_INSTANCES));
group.bench_function(format!("poseidon2 2^{} instances", LOG_N_INSTANCES), |b| {
b.iter(|| prove_poseidon(LOG_N_INSTANCES, PcsConfig::default()));
});
}
criterion_group!(
name = bit_rev;
config = Criterion::default().sample_size(10);
targets = simd_poseidon);
criterion_main!(bit_rev);

View File

@ -0,0 +1,19 @@
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use stwo_prover::core::backend::simd::column::BaseColumn;
use stwo_prover::core::backend::simd::prefix_sum::inclusive_prefix_sum;
use stwo_prover::core::fields::m31::BaseField;
pub fn simd_prefix_sum_bench(c: &mut Criterion) {
const LOG_SIZE: u32 = 24;
let evals: BaseColumn = (0..1 << LOG_SIZE).map(BaseField::from).collect();
c.bench_function(&format!("simd prefix_sum 2^{LOG_SIZE}"), |b| {
b.iter_batched(
|| evals.clone(),
inclusive_prefix_sum,
BatchSize::LargeInput,
);
});
}
criterion_group!(benches, simd_prefix_sum_bench);
criterion_main!(benches);

View File

@ -0,0 +1,55 @@
#![feature(iter_array_chunks)]
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use itertools::Itertools;
use stwo_prover::core::backend::cpu::CpuBackend;
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::circle::SECURE_FIELD_CIRCLE_GEN;
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::pcs::quotients::{ColumnSampleBatch, QuotientOps};
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;
// TODO(andrew): Consider removing const generics and making all sizes the same.
fn bench_quotients<B: QuotientOps, const LOG_N_ROWS: u32, const LOG_N_COLS: u32>(
c: &mut Criterion,
id: &str,
) {
let domain = CanonicCoset::new(LOG_N_ROWS).circle_domain();
let values = (0..domain.size()).map(BaseField::from).collect();
let col = CircleEvaluation::<B, BaseField, BitReversedOrder>::new(domain, values);
let cols = (0..1 << LOG_N_COLS).map(|_| col.clone()).collect_vec();
let col_refs = cols.iter().collect_vec();
let random_coeff = SecureField::from_u32_unchecked(0, 1, 2, 3);
let a = SecureField::from_u32_unchecked(5, 6, 7, 8);
let samples = vec![ColumnSampleBatch {
point: SECURE_FIELD_CIRCLE_GEN,
columns_and_values: (0..1 << LOG_N_COLS).map(|i| (i, a)).collect(),
}];
c.bench_function(
&format!("{id} quotients 2^{LOG_N_COLS} x 2^{LOG_N_ROWS}"),
|b| {
b.iter_with_large_drop(|| {
B::accumulate_quotients(
black_box(domain),
black_box(&col_refs),
black_box(random_coeff),
black_box(&samples),
1,
)
})
},
);
}
fn quotients_benches(c: &mut Criterion) {
bench_quotients::<SimdBackend, 20, 8>(c, "simd");
bench_quotients::<CpuBackend, 16, 8>(c, "cpu");
}
criterion_group!(
name = benches;
config = Criterion::default().sample_size(10);
targets = quotients_benches);
criterion_main!(benches);

View File

@ -0,0 +1,348 @@
{
"commitments" :
["34328580272026076035687604093297365442785733592720865218001799813393342152908",
"38388381845372648579572899115609862601821983406101214230086519922780265042634"],
"sampled_values_0" :
[["1","0","0","0"],
["2129160320","1109509513","787887008","1676461964"],
["262908602","915488457","1893945291","1774327476"],
["894719153","1570509766","1424186619","204092576"],
["397490811","836398274","1615765624","2013800563"],
["1022303904","276983775","1064742229","165204856"],
["1200363525","170838026","524999776","156116441"],
["850733526","448725560","1521962209","1318190714"],
["1187866075","1705588092","924088348","490002418"],
["2033565088","996780784","1820235518","2048788344"],
["2061590372","1150986157","711772586","1511398564"],
["1066623954","530384603","1890251380","1699008129"],
["734047580","1685768538","505142109","787113212"],
["2030904700","99932423","695391286","1736941035"],
["1580330105","932031717","1705998668","146411959"],
["1585732224","1556242253","941668238","1998570239"],
["199481433","2123320403","1257464748","1663811899"],
["2139019524","1547107722","728449250","1941851166"],
["752079023","268472135","1465850435","16510773"],
["1279312817","63252415","442230579","1560954631"],
["1074859131","137997593","2118329011","652535723"],
["297567647","1483381078","1941495981","599737348"],
["1735543786","1420676479","1354982762","1114211268"],
["1691705401","1143446295","1748115479","1666756627"],
["955696743","2077778309","736065989","1319443838"],
["1076874307","1001483910","1702287354","819727011"],
["1134989244","1823710400","2067694105","1098263343"],
["1793642608","961404475","1279773056","1815400043"],
["739677274","1827877577","838562378","171296720"],
["2036367121","1901888610","289723252","2014426907"],
["330020507","436937516","2113056521","1828501207"],
["1359068814","583899921","734628376","1223217137"],
["1319501520","1242972089","1202216521","1285024997"],
["681182370","1569622309","1574376904","1563950435"],
["1204519566","483612224","1677731115","1667757584"],
["330284364","917877098","57538161","179869993"],
["2056561198","119768893","740294154","1454562198"],
["79009084","545196641","13388962","1973400144"],
["885977898","1973300145","37115619","957100699"],
["1937449867","1777683674","1983002799","757662558"],
["344927561","357845689","26887161","664585634"],
["1462268220","615463524","209500386","44308852"],
["570984705","2022111132","1404632615","2119081660"],
["13183327","1584451280","1216116653","316345540"],
["1497965915","705236857","1892466476","1068567492"],
["1758694676","1408790161","1140545981","315723937"],
["645308461","1125824784","1786470558","1240927727"],
["1213464061","470930291","1718629724","1149088875"],
["214577693","1578610321","2133720991","226291629"],
["1357706729","2097875841","1767996253","1478111500"],
["1154658683","752162439","2018723944","163997560"],
["1051993583","703716977","379706674","487262860"],
["1017692573","2060296775","2001023083","1064213951"],
["1042587725","1701108370","204550428","904590130"],
["1115340870","743420370","1927225111","1276396551"],
["493638626","1874789377","47342513","209203758"],
["1558586505","83459476","247638703","1975504267"],
["2097068784","954319448","367516919","1545761518"],
["1655645294","352838520","1307263981","1110198118"],
["1169856046","1925368371","1362317240","1926032147"],
["1940113709","885624001","1395047654","80053995"],
["1778932990","25092730","201117282","1724571908"],
["2096327738","233411984","1247443120","713989449"],
["808532602","136577890","1015579288","38900716"],
["1182257782","1186245376","1451332036","2080170103"],
["1662610758","1505542080","1038243031","1889715771"],
["440146119","942837214","1440484295","1593949278"],
["46258268","1884246120","164930024","2050584510"],
["1198954868","1079638495","1424072583","1028611344"],
["2112984649","1531382496","1873151714","1818301795"],
["1554382282","253920307","1641628530","1378998084"],
["857898234","686236793","2091871553","184978860"],
["2049153599","6111471","1579475775","32492894"],
["1371356596","679072793","1547377985","354305233"],
["1799882226","1201472049","1592617716","125534957"],
["1277144880","253726080","1800145982","1125162267"],
["1577717920","440984421","1377891036","846453148"],
["1952731919","1710992214","673668053","1871913638"],
["1559011028","2060945859","719954448","1356468891"],
["1961642242","1693473944","1300152522","412222111"],
["861208187","1242659514","977183954","38730935"],
["1016984917","1368361439","2106430139","1225979890"],
["1427754325","1482206106","1465316380","1096279813"],
["566051043","2025874544","234976335","1482256978"],
["1750543495","1494541462","374330732","411642241"],
["230654343","55625728","136463431","1099606808"],
["1172218793","1260458608","1314942990","75527287"],
["1824515276","916178746","1300275105","370626746"],
["915931367","987018043","56193044","617907884"],
["1934695822","1112844637","609268252","1972086910"],
["619631651","152029630","1979976905","292597437"],
["62258350","1890115432","1373605674","1505619938"],
["1770422019","1398189304","1773172351","1576001433"],
["650940868","1756047014","1764798953","1146887875"],
["1746945043","528205234","778346028","1797468521"],
["760802416","1479409742","1556974632","1307498378"],
["102511022","1787975482","968854748","1010240763"],
["330722054","2046294448","14132125","1822414050"],
["943548871","1770900623","1861740461","1290634078"],
["1402661415","1361511065","1784889120","837615360"]],
"sampled_values_1" :
[["712066144","1576368753","626134398","426337436"],
["160634493","1096735733","992622982","964509862"],
["208900621","1128739590","1423579079","1688318061"],
["1029182234","1152361165","571476481","1593867154"]],
"decommitment_0" :
["24311567319749512546399129581715033328970605051392227451685196018312506896509",
"8450134967305372517473027560161707471995673370792264153422077885080332622841",
"6431507699794114682519586182713221908058047520896405293833270087934517909753",
"303109001984349840640377328716025252051982378448629744935456455431709129012",
"47167328465744900593371601186109726758160197572292632388959155138584359581158",
"50584492046778480438774038937088410409133167768957478525289857065775850658491",
"30584798499699103841624545814425958941934653399588880797257122471101102880636",
"33441256878213890325682161124370878299436204406591246133637659120215439522803",
"3288124068330032280185519028600654292250668929588668389702892483946668251740",
"29852774919556057664485671676242264613416836486089146650713214180894511265116",
"12482060975231949385592255321766253365687502822944549845564491620341379321204",
"46285234163162336949700608657672147469543559995399282843606812790099228411758",
"20807128972645591294726020136444795908525656782422245307591812614900798799914"],
"decommitment_1" :
["18063303111481257844109225560025890393366258018933166919604543575686388632162",
"1676364734386980395984608216327451243278421019544108756198322792517099196249",
"5278661052518480850653886996628582549184134231869598116690316714367933376948",
"21983822689977371558234298346357617674436224016274009820764238516520240403273",
"605332543427376153930374757063581881998320956602375739165671986207155079359",
"33771702906041565783498389127165108212044382608172583325407071671862086994048",
"6930451780154275491146135719028766497977496109537963233244808739657647563071",
"4564117668410212714684125903928600765456322272915099403587425756488534507713",
"337808767671877648828499299861821973796749820854854708379479049898835100991",
"2725840354457305623692571800192492803162041315546256970381708693201407812833",
"34716495111790106826563330917176360656701717867702196654990744866499300990003",
"49719870445464463616785535809529171382800153139923763422202182182379572350737",
"48664746464641275030915461677298150155193593333108431337941029583245720868695",
"32557886668297237033601675259512842580727821006475208499640136322794706303894",
"15014835402414421167586357788116276188694467622586221351644991310645286648480",
"12973705814659120511327850727547995427054971555827754777395732787633567627149",
"15379912850866472398958956306527914195058439699787840041152620034933267404138",
"19859070819439084101412868355121176941090844577507824922960697697668791429525",
"38273559034692361632775489704953448699371080776239846995670381153186834620044"],
"queried_values_0" :
[["1","1","1","1","1","1"],
["730457281","730490049","28918683","28885915","1656126010","1656093242"],
["855614122","1238037465","1836504291","355791428","757095818","467806903"],
["674179888","1530445315","1720543014","76190330","1475912409","1017215862"],
["1142290008","1148671853","1619097781","938511401","904357795","257652679"],
["1679234056","1355264641","2139729457","574756654","604307234","1146556949"],
["1000500309","2008905806","1442759180","598876729","1786070690","1072293976"],
["1119085545","2133345582","135683580","216214405","1049766224","943727969"],
["206423262","2047139937","305085364","1422472664","1826554088","1032095092"],
["501238882","1656305868","724710382","1949772461","1426787917","585368894"],
["1005468045","1775577441","1042182076","415631363","1067013227","1635705270"],
["1776076392","216798814","1525036520","1160666510","1212132211","1915058776"],
["859923105","1633989410","182110635","2060185314","1084464822","1129902257"],
["489437802","313401022","271315129","357612175","2050381179","647577687"],
["1495302158","2052264981","1498165299","1164417520","1050104037","450244199"],
["1084986392","398966983","808449145","1733554138","2068501028","659474347"],
["399458768","1789245133","1698759035","188433436","1794535430","364419824"],
["2013965647","722839714","928854328","124488895","1378959529","952886009"],
["1334765706","193402268","471076108","640800921","1998121783","961582406"],
["1067762968","381831281","560459357","1025929344","181659877","1922040224"],
["1993303462","467991218","849673597","744722836","239634354","329631295"],
["785794488","1649178388","672964420","1281255462","900602801","271501809"],
["857859728","1325395820","985014020","1094321795","259553347","774587048"],
["1214640090","1588569866","871717820","1131833706","1625896842","1635087550"],
["796549205","931495223","2018253108","1395065060","158209751","1160478135"],
["883143962","729115354","190207821","839273168","1668931939","2074584689"],
["1490296658","1846956206","1610364850","56422972","160482417","681872093"],
["1270585092","1910190167","464113273","613529242","1027101122","1014185686"],
["1456043179","1999662961","193940913","678382864","39040067","1236859818"],
["1626243617","901735777","1703169024","911300891","1640727682","1121874896"],
["492192896","15672698","319327174","1727120334","1965889437","114404366"],
["407079019","949462637","255390508","1753162095","501134776","1457122467"],
["1478573872","1439193434","1053200675","1001140887","1553935777","1253681552"],
["183135520","946237525","1802924023","1831496784","1893117930","1830486286"],
["234902670","1169030504","196055115","1323151968","855748623","1328842866"],
["1150999776","1338824346","2072101698","774206263","1967350016","1808817867"],
["924341552","1430286424","511268814","825025920","1061850574","1954646566"],
["302634890","314434153","1692670768","1822915313","1244352075","1953834230"],
["1576167467","687837005","2116136752","144109400","1590157548","1634932462"],
["396756275","1272134898","1207308240","818219166","1314182589","109494000"],
["846425160","897737569","757312164","826009489","1019831588","1977463051"],
["2065801114","1918982367","1548689186","2082631803","298112070","383438809"],
["1034102289","461735180","2115581275","1343026598","1229979058","1021418523"],
["1784173874","166635387","547550115","1094693960","573193735","451367040"],
["119818313","659105018","1741377697","8940733","911200334","511474518"],
["1511949880","1315119529","1267019200","2134944693","878254810","375758264"],
["1203050254","156394547","1348568635","412863443","1068659960","1407913814"],
["361779719","130417374","89109096","117994876","1151322919","863143484"],
["1007476533","989566160","138644964","1672742874","540141118","1296408100"],
["1824241144","2051199719","1863718547","2109877864","36689613","1055926854"],
["791693003","1433717239","991140958","1565955371","1839976870","1163947838"],
["1267320759","2102593211","1831360854","1691591439","1672201908","61327345"],
["301343164","277158258","627925439","577508975","1896464649","907629062"],
["1964268932","929590164","1529686876","68630644","1663063136","254082844"],
["693529348","1815295486","1660565870","1226857377","156312343","1500907098"],
["1723158753","252348225","253985470","52424437","1605949937","576572581"],
["1781792048","1497492716","1951824572","1156925855","863650708","156447987"],
["876432605","458503399","283092867","247883110","1227074181","966219235"],
["1581118191","66527915","1577039825","1227961402","1738412997","1862462297"],
["679458448","338624032","34185999","253532412","65409631","563033132"],
["1011967612","1898273226","2124401156","105282260","1188226330","123913515"],
["1432513005","1083162825","1299704150","1184276814","339749370","1064298821"],
["145070927","1250457746","1977306722","1035124433","322154361","1782232869"],
["1934256728","1269667423","248999825","267009036","132507662","474021315"],
["1694871453","535187899","981724703","1180312550","74370795","277702656"],
["1850841912","1037634121","1967377497","1755127193","1566449422","1939039785"],
["1315699598","476111459","2058733537","332263289","1592057567","874912616"],
["774536537","170060616","2086574090","47894465","778021586","2115296942"],
["322468558","24934377","637275739","1596346002","1896623296","1814433409"],
["766517428","1263076038","358941187","2070217919","2108397185","1587546402"],
["345404490","1065320570","1231275245","1037359122","1286389839","2070140848"],
["746521574","835067673","311114030","1586400488","1406022058","1284151326"],
["1857315969","431410759","825259098","1717904860","503708539","2097758215"],
["1879479734","1863555039","2108235515","1833922769","1562707156","49484002"],
["1366768987","1050390036","1491845132","666041968","74368055","1254335623"],
["188857287","1161878039","1771805176","1457666227","1157868840","486461459"],
["261764705","1577846886","1332322961","10423372","640027252","1086814656"],
["111907709","542625019","2021749229","2013008690","523703611","1328833940"],
["1270684472","1675989474","394214608","538100201","1984625073","1560563159"],
["666555709","852557426","651115051","1878827907","953346499","619017191"],
["747972907","149382079","1393306586","1394823957","960994901","536632180"],
["80535893","229602380","1817483938","1455260088","484432","1869486290"],
["443556931","253108261","1609174393","1245931188","752691602","1668543792"],
["745497042","854686466","1834097777","642389535","1284043061","896553209"],
["532777064","1491985134","200157005","1378855967","1159213374","1797221037"],
["933463176","813761538","1124049829","1988347055","2115297439","1836576920"],
["194436043","1437728625","43998833","786326005","1130428925","1424571033"],
["2108272636","1410841489","753065553","2020187193","1644376367","670324352"],
["1362448669","752702510","1740531646","47989265","588634","1940480814"],
["439960422","528245604","179496898","235775013","59000527","1903150726"],
["103605138","249162711","1971219628","1958189530","423278905","1318354885"],
["321504059","1595801356","596911575","1361967073","459661104","599048233"],
["1610552125","73166668","444776743","1820306524","1180674369","1570356756"],
["1283703846","1024562975","958477092","1329464736","1758672211","899108631"],
["713626137","904634570","1902566483","1938333063","447549083","703262660"],
["1417291696","1717451368","354584524","832751684","930128006","1037860604"],
["1618745108","79533863","301008038","2091942909","1221962725","1524945081"],
["1490224031","349040760","393684137","484089443","1912848485","1790207999"],
["1930411520","642009628","1138820074","31855314","1177766391","1913457637"],
["618628623","139430131","904498895","925273128","2111653256","1012250155"]],
"queried_values_1" : [["316772341","1526280133","663010112","224983897","510598760","1109503351"],
["754832207","435790299","883623752","553207508","154784232","199176676"],
["689603315","1763523007","1720552945","1983603154","367841669","319325418"],
["1290247052","1120744584","193500372","294491115","951360807","891034447"]],
"proof of work" : "43",
"coeffs" : ["329725079","667313404","2083859876","1645693780"],
"inner_commitment_0" : "45555755014923146766476222823122654194153582923386372098787021809602091298670",
"inner_decommitment_0" :
["5462985033728555575703006689913665598917262836577853690370070265073174719979",
"1439463537898028163672031322449473184361454650351831891678420729829634422052",
"3457992889375232257419443221173459860069710025424105705342440558952480618733",
"1560236756242436900530359200127475252176167690738804367783218449678388748008",
"44704035195642947074853484428073358106137689624692539529545008114744235429869",
"1823722396825684657353300460130959150013626547749061669644787096209462088081",
"36173961213090661587166983586606381085726394354099487131775975739295244107044",
"44982414287882497869227873977178787463294164719732705041425254748922385395330",
"45309152665928987599895709328729600158686240484613667382186228529382344291065",
"6727671460449390719545303211873485738734289534911560838410979450236173250575",
"20261836592905556803606993653450529841562498576647326913889589677099044443740",
"5092741370842944151567575040255811031129982749889308756570945791769905531735",
"25669118913473215056777887930215946759071827607816131638248897691395809170466"],
"inner_evals_subset_0" :
[["1779738283","1440487135","622229563","2027928845"],
["3310770","2003547458","1663490902","2105455978"],
["824310523","1757518542","231582441","427507918"]],
"inner_commitment_1" : "36809196688918736785151875655523363274066842107451407514854169291514772437712",
"inner_decommitment_1" :
["45836684762941279488847688946861562185184567552524644394596887832754938056979",
"19073399120361742411025793373596546929954434607182120440274945820526786807031",
"10149408627738268983458395386544516732232569888429637086606165543488301475287",
"1134561121197915913438467336074096605141057500441242329860673894339623319838",
"4117331370612733116088255919125911273783755080775248279590234839282346655209",
"10513576479117646185507147947418059201301955276798193127377481002238952180861",
"23342730846670478933566004300392843380760787769231395287565422468403932432226",
"18740447275326464369699774196183138103532792867824812360537906002797944316369",
"14414140432147963417180883803716587948502362440108737731008978168633980398219",
"12250577863311462987061070846596950168057013378171600204130942328863618210853"],
"inner_evals_subset_1" :
[["1993244979","993480712","1202910330","51538179"],
["1550507975","1444313548","117070947","1740854590"],
["342709844","601149328","1436490544","1384381104"]],
"inner_commitment_2" : "881819876414071785116043072319901261019430342891967010427739931379717752179",
"inner_decommitment_2" :
["3578061893060038632121836895066391994380789049478144565795630968916166958170",
"37159117331259094338569510749131708143302385613991325356583099983565421599794",
"37757292341028790385725896863222501721662517262004109998718347744172059822092",
"32039058093629437909283983486827138804118789892152561698628703527746673098335",
"11826129489683399268430621966121477905203281789801973078750629518404118320344",
"12209931412475259378582945309327673629248630080862144675946167350377528046777",
"23652788355891929123107872712785512551735081591713325931570432712643436668697"],
"inner_evals_subset_2" :
[["1184374976","203117688","803515935","1781630737"],
["716686867","138132852","2024080584","392488646"],
["606958913","308986056","258114411","2075401741"]],
"inner_commitment_3" : "12027868599227153144742193285247060272784688895537159385948117930165143367635",
"inner_decommitment_3" :
["46510393127320994984678433868779353751380750916123221099220042551249644407575",
"39205406309151711272632862117612882165913578213907083849769292791373274291092",
"1794620160371318211300608665903463515892826349097148444913826356018179256828",
"30764929521910551485060681298929856016753563990361023701791813032517030674324"],
"inner_evals_subset_3" :
[["1682558787","129089003","784689440","491206249"],
["2038210709","1600238918","655676259","1542271403"],
["1014579052","1384080403","862591487","1941843578"]],
"inner_commitment_4" : "14412699124489796400221638504502701429059474709940645969751865021865037702257",
"inner_decommitment_4" :
["19324767949149751902195880760061491860991545124249052461044586503970003688610"],
"inner_evals_subset_4" :
[["683258805","1002722262","1583421272","1748673499"],
["2101847208","689925082","1602280602","1942656531"],
["1952987285","1995490213","2082219584","1620868519"]],
"inner_commitment_5" : "7898658461322497542615494384418597990320440781916208640366283709415937814000",
"inner_decommitment_5" :
[],
"inner_evals_subset_5" :
[["1862940297","2014478284","1043383827","560191545"]]
}

View File

@ -0,0 +1,84 @@
use num_traits::{One, Zero};
use super::EvalAtRow;
use crate::core::backend::{Backend, Column};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CirclePoly};
use crate::core::utils::circle_domain_order_to_coset_order;
/// Evaluates expressions at a trace domain row, and asserts constraints. Mainly used for testing.
pub struct AssertEvaluator<'a> {
pub trace: &'a TreeVec<Vec<Vec<BaseField>>>,
pub col_index: TreeVec<usize>,
pub row: usize,
}
impl<'a> AssertEvaluator<'a> {
pub fn new(trace: &'a TreeVec<Vec<Vec<BaseField>>>, row: usize) -> Self {
Self {
trace,
col_index: TreeVec::new(vec![0; trace.len()]),
row,
}
}
}
impl<'a> EvalAtRow for AssertEvaluator<'a> {
type F = BaseField;
type EF = SecureField;
fn next_interaction_mask<const N: usize>(
&mut self,
interaction: usize,
offsets: [isize; N],
) -> [Self::F; N] {
let col_index = self.col_index[interaction];
self.col_index[interaction] += 1;
offsets.map(|off| {
// The mask row might wrap around the column size.
let col_size = self.trace[interaction][col_index].len() as isize;
self.trace[interaction][col_index]
[(self.row as isize + off).rem_euclid(col_size) as usize]
})
}
fn add_constraint<G>(&mut self, constraint: G)
where
Self::EF: std::ops::Mul<G, Output = Self::EF>,
{
// Cast to SecureField.
let res = SecureField::one() * constraint;
// The constraint should be zero at the given row, since we are evaluating on the trace
// domain.
assert_eq!(res, SecureField::zero(), "row: {}", self.row);
}
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
SecureField::from_m31_array(values)
}
}
pub fn assert_constraints<B: Backend>(
trace_polys: &TreeVec<Vec<CirclePoly<B>>>,
trace_domain: CanonicCoset,
assert_func: impl Fn(AssertEvaluator<'_>),
) {
let traces = trace_polys.as_ref().map(|tree| {
tree.iter()
.map(|poly| {
circle_domain_order_to_coset_order(
&poly
.evaluate(trace_domain.circle_domain())
.bit_reverse()
.values
.to_cpu(),
)
})
.collect()
});
for row in 0..trace_domain.size() {
let eval = AssertEvaluator::new(&traces, row);
assert_func(eval);
}
}

View File

@ -0,0 +1,210 @@
use std::borrow::Cow;
use std::iter::zip;
use std::ops::Deref;
use itertools::Itertools;
use tracing::{span, Level};
use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentProver, Trace};
use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS};
use crate::core::backend::simd::SimdBackend;
use crate::core::circle::CirclePoint;
use crate::core::constraints::coset_vanishing;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::pcs::{TreeSubspan, TreeVec};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
use crate::core::poly::BitReversedOrder;
use crate::core::{utils, ColumnVec};
// TODO(andrew): Docs.
// TODO(andrew): Consider better location for this.
#[derive(Debug, Default)]
pub struct TraceLocationAllocator {
/// Mapping of tree index to next available column offset.
next_tree_offsets: TreeVec<usize>,
}
impl TraceLocationAllocator {
fn next_for_structure<T>(&mut self, structure: &TreeVec<ColumnVec<T>>) -> TreeVec<TreeSubspan> {
if structure.len() > self.next_tree_offsets.len() {
self.next_tree_offsets.resize(structure.len(), 0);
}
TreeVec::new(
zip(&mut *self.next_tree_offsets, &**structure)
.enumerate()
.map(|(tree_index, (offset, cols))| {
let col_start = *offset;
let col_end = col_start + cols.len();
*offset = col_end;
TreeSubspan {
tree_index,
col_start,
col_end,
}
})
.collect(),
)
}
}
/// A component defined solely in means of the constraints framework.
/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for
/// the SIMD backend.
/// Note that the constraint framework only support components with columns of the same size.
pub trait FrameworkEval {
fn log_size(&self) -> u32;
fn max_constraint_log_degree_bound(&self) -> u32;
fn evaluate<E: EvalAtRow>(&self, eval: E) -> E;
}
pub struct FrameworkComponent<C: FrameworkEval> {
eval: C,
trace_locations: TreeVec<TreeSubspan>,
}
impl<E: FrameworkEval> FrameworkComponent<E> {
pub fn new(provider: &mut TraceLocationAllocator, eval: E) -> Self {
let eval_tree_structure = eval.evaluate(InfoEvaluator::default()).mask_offsets;
let trace_locations = provider.next_for_structure(&eval_tree_structure);
Self {
eval,
trace_locations,
}
}
}
impl<E: FrameworkEval> Component for FrameworkComponent<E> {
fn n_constraints(&self) -> usize {
self.eval.evaluate(InfoEvaluator::default()).n_constraints
}
fn max_constraint_log_degree_bound(&self) -> u32 {
self.eval.max_constraint_log_degree_bound()
}
fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::new(
self.eval
.evaluate(InfoEvaluator::default())
.mask_offsets
.iter()
.map(|tree_masks| vec![self.eval.log_size(); tree_masks.len()])
.collect(),
)
}
fn mask_points(
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
let info = self.eval.evaluate(InfoEvaluator::default());
let trace_step = CanonicCoset::new(self.eval.log_size()).step();
info.mask_offsets.map_cols(|col_mask| {
col_mask
.iter()
.map(|off| point + trace_step.mul_signed(*off).into_ef())
.collect()
})
}
fn evaluate_constraint_quotients_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &TreeVec<ColumnVec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
) {
self.eval.evaluate(PointEvaluator::new(
mask.sub_tree(&self.trace_locations),
evaluation_accumulator,
coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(),
));
}
}
impl<E: FrameworkEval> ComponentProver<SimdBackend> for FrameworkComponent<E> {
fn evaluate_constraint_quotients_on_domain(
&self,
trace: &Trace<'_, SimdBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<SimdBackend>,
) {
let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain();
let trace_domain = CanonicCoset::new(self.eval.log_size());
let component_polys = trace.polys.sub_tree(&self.trace_locations);
let component_evals = trace.evals.sub_tree(&self.trace_locations);
// Extend trace if necessary.
// TODO(spapini): Don't extend when eval_size < committed_size. Instead, pick a good
// subdomain.
let need_to_extend = component_evals
.iter()
.flatten()
.any(|c| c.domain != eval_domain);
let trace: TreeVec<
Vec<Cow<'_, CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>>,
> = if need_to_extend {
let _span = span!(Level::INFO, "Extension").entered();
let twiddles = SimdBackend::precompute_twiddles(eval_domain.half_coset);
component_polys
.as_cols_ref()
.map_cols(|col| Cow::Owned(col.evaluate_with_twiddles(eval_domain, &twiddles)))
} else {
component_evals.clone().map_cols(|c| Cow::Borrowed(*c))
};
// Denom inverses.
let log_expand = eval_domain.log_size() - trace_domain.log_size();
let mut denom_inv = (0..1 << log_expand)
.map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse())
.collect_vec();
utils::bit_reverse(&mut denom_inv);
// Accumulator.
let [mut accum] =
evaluation_accumulator.columns([(eval_domain.log_size(), self.n_constraints())]);
accum.random_coeff_powers.reverse();
let _span = span!(Level::INFO, "Constraint pointwise eval").entered();
let col = unsafe { VeryPackedSecureColumnByCoords::transform_under_mut(accum.col) };
for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)) {
let trace_cols = trace.as_cols_ref().map_cols(|c| c.as_ref());
// Evaluate constrains at row.
let eval = SimdDomainEvaluator::new(
&trace_cols,
vec_row,
&accum.random_coeff_powers,
trace_domain.log_size(),
eval_domain.log_size(),
);
let row_res = self.eval.evaluate(eval).row_res;
// Finalize row.
unsafe {
let denom_inv = VeryPackedBaseField::broadcast(
denom_inv[vec_row
>> (trace_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)],
);
col.set_packed(vec_row, col.packed_at(vec_row) + row_res * denom_inv)
}
}
}
}
impl<E: FrameworkEval> Deref for FrameworkComponent<E> {
type Target = E;
fn deref(&self) -> &E {
&self.eval
}
}

View File

@ -0,0 +1,37 @@
use num_traits::One;
use crate::core::backend::{Backend, Col, Column};
use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};
/// Generates a column with a single one at the first position, and zeros elsewhere.
pub fn gen_is_first<B: Backend>(log_size: u32) -> CircleEvaluation<B, BaseField, BitReversedOrder> {
let mut col = Col::<B, BaseField>::zeros(1 << log_size);
col.set(0, BaseField::one());
CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col)
}
/// Generates a column with `1` at every `2^log_step` positions, `0` elsewhere, shifted by offset.
// TODO(andrew): Consider optimizing. Is a quotients of two coset_vanishing (use succinct rep for
// verifier).
pub fn gen_is_step_with_offset<B: Backend>(
log_size: u32,
log_step: u32,
offset: usize,
) -> CircleEvaluation<B, BaseField, BitReversedOrder> {
let mut col = Col::<B, BaseField>::zeros(1 << log_size);
let size = 1 << log_size;
let step = 1 << log_step;
let step_offset = offset % step;
for i in (step_offset..size).step_by(step) {
let circle_domain_index = coset_index_to_circle_domain_index(i, log_size);
let circle_domain_index_bit_rev = bit_reverse_index(circle_domain_index, log_size);
col.set(circle_domain_index_bit_rev, BaseField::one());
}
CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col)
}

View File

@ -0,0 +1,48 @@
use std::ops::Mul;
use num_traits::One;
use super::EvalAtRow;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::pcs::TreeVec;
/// Collects information about the constraints.
/// This includes mask offsets and columns at each interaction, and the number of constraints.
#[derive(Default)]
pub struct InfoEvaluator {
pub mask_offsets: TreeVec<Vec<Vec<isize>>>,
pub n_constraints: usize,
}
impl InfoEvaluator {
pub fn new() -> Self {
Self::default()
}
}
impl EvalAtRow for InfoEvaluator {
type F = BaseField;
type EF = SecureField;
fn next_interaction_mask<const N: usize>(
&mut self,
interaction: usize,
offsets: [isize; N],
) -> [Self::F; N] {
// Check if requested a mask from a new interaction
if self.mask_offsets.len() <= interaction {
// Extend `mask_offsets` so that `interaction` is the last index.
self.mask_offsets.resize(interaction + 1, vec![]);
}
self.mask_offsets[interaction].push(offsets.into_iter().collect());
[BaseField::one(); N]
}
fn add_constraint<G>(&mut self, _constraint: G)
where
Self::EF: Mul<G, Output = Self::EF>,
{
self.n_constraints += 1;
}
fn combine_ef(_values: [Self::F; 4]) -> Self::EF {
SecureField::one()
}
}

View File

@ -0,0 +1,315 @@
use std::ops::{Mul, Sub};
use itertools::Itertools;
use num_traits::{One, Zero};
use super::EvalAtRow;
use crate::core::backend::simd::column::SecureColumn;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::Column;
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE};
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;
/// Evaluates constraints for batched logups.
/// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = claimed_sum.
/// BATCH_SIZE is the number of fractions to batch together. The degree of the resulting constraints
/// will be BATCH_SIZE + 1.
pub struct LogupAtRow<const BATCH_SIZE: usize, E: EvalAtRow> {
/// The index of the interaction used for the cumulative sum columns.
pub interaction: usize,
/// Queue of fractions waiting to be batched together.
pub queue: [(E::EF, E::EF); BATCH_SIZE],
/// Number of fractions in the queue.
pub queue_size: usize,
/// A constant to subtract from each row, to make the totall sum of the last column zero.
/// In other words, claimed_sum / 2^log_size.
/// This is used to make the constraint uniform.
pub cumsum_shift: SecureField,
/// The evaluation of the last cumulative sum column.
pub prev_col_cumsum: E::EF,
is_finalized: bool,
}
impl<const BATCH_SIZE: usize, E: EvalAtRow> LogupAtRow<BATCH_SIZE, E> {
pub fn new(interaction: usize, claimed_sum: SecureField, log_size: u32) -> Self {
Self {
interaction,
queue: [(E::EF::zero(), E::EF::zero()); BATCH_SIZE],
queue_size: 0,
cumsum_shift: claimed_sum / BaseField::from_u32_unchecked(1 << log_size),
prev_col_cumsum: E::EF::zero(),
is_finalized: false,
}
}
pub fn push_lookup<const N: usize>(
&mut self,
eval: &mut E,
numerator: E::EF,
values: &[E::F],
lookup_elements: &LookupElements<N>,
) {
let shifted_value = lookup_elements.combine(values);
self.push_frac(eval, numerator, shifted_value);
}
pub fn push_frac(&mut self, eval: &mut E, numerator: E::EF, denominator: E::EF) {
if self.queue_size < BATCH_SIZE {
self.queue[self.queue_size] = (numerator, denominator);
self.queue_size += 1;
return;
}
// Compute sum_i pi/qi over batch, as a fraction, num/denom.
let (num, denom) = self.fold_queue();
self.queue[0] = (numerator, denominator);
self.queue_size = 1;
// Add a constraint that num / denom = diff.
let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0];
let diff = cur_cumsum - self.prev_col_cumsum;
self.prev_col_cumsum = cur_cumsum;
eval.add_constraint(diff * denom - num);
}
pub fn add_frac(&mut self, eval: &mut E, fraction: Fraction<E::EF, E::EF>) {
// Add a constraint that num / denom = diff.
let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0];
let diff = cur_cumsum - self.prev_col_cumsum;
self.prev_col_cumsum = cur_cumsum;
eval.add_constraint(diff * fraction.denominator - fraction.numerator);
}
pub fn finalize(mut self, eval: &mut E) {
assert!(!self.is_finalized, "LogupAtRow was already finalized");
let (num, denom) = self.fold_queue();
let [cur_cumsum, prev_row_cumsum] =
eval.next_extension_interaction_mask(self.interaction, [0, -1]);
let diff = cur_cumsum - prev_row_cumsum - self.prev_col_cumsum;
// Instead of checking diff = num / denom, check diff = num / denom - cumsum_shift.
// This makes (num / denom - cumsum_shift) have sum zero, which makes the constraint
// uniform - apply on all rows.
let fixed_diff = diff + self.cumsum_shift;
eval.add_constraint(fixed_diff * denom - num);
self.is_finalized = true;
}
fn fold_queue(&self) -> (E::EF, E::EF) {
self.queue[0..self.queue_size]
.iter()
.copied()
.fold((E::EF::zero(), E::EF::one()), |(p0, q0), (pi, qi)| {
(p0 * qi + pi * q0, qi * q0)
})
}
}
/// Ensures that the LogupAtRow is finalized.
/// LogupAtRow should be finalized exactly once.
impl<const BATCH_SIZE: usize, E: EvalAtRow> Drop for LogupAtRow<BATCH_SIZE, E> {
fn drop(&mut self) {
assert!(self.is_finalized, "LogupAtRow was not finalized");
}
}
/// Interaction elements for the logup protocol.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LookupElements<const N: usize> {
pub z: SecureField,
pub alpha: SecureField,
alpha_powers: [SecureField; N],
}
impl<const N: usize> LookupElements<N> {
pub fn draw(channel: &mut impl Channel) -> Self {
let [z, alpha] = channel.draw_felts(2).try_into().unwrap();
let mut cur = SecureField::one();
let alpha_powers = std::array::from_fn(|_| {
let res = cur;
cur *= alpha;
res
});
Self {
z,
alpha,
alpha_powers,
}
}
pub fn combine<F: Copy, EF>(&self, values: &[F]) -> EF
where
EF: Copy + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<EF, Output = EF>,
{
EF::from(values[0])
+ values[1..]
.iter()
.zip(self.alpha_powers.iter())
.fold(EF::zero(), |acc, (&value, &power)| {
acc + EF::from(power) * value
})
- EF::from(self.z)
}
// TODO(spapini): Try to remove this.
pub fn dummy() -> Self {
Self {
z: SecureField::one(),
alpha: SecureField::one(),
alpha_powers: [SecureField::one(); N],
}
}
}
// SIMD backend generator for logup interaction trace.
pub struct LogupTraceGenerator {
log_size: u32,
/// Current allocated interaction columns.
trace: Vec<SecureColumnByCoords<SimdBackend>>,
/// Denominator expressions (z + sum_i alpha^i * x_i) being generated for the current lookup.
denom: SecureColumn,
/// Preallocated buffer for the Inverses of the denominators.
denom_inv: SecureColumn,
}
impl LogupTraceGenerator {
pub fn new(log_size: u32) -> Self {
let trace = vec![];
let denom = SecureColumn::zeros(1 << log_size);
let denom_inv = SecureColumn::zeros(1 << log_size);
Self {
log_size,
trace,
denom,
denom_inv,
}
}
/// Allocate a new lookup column.
pub fn new_col(&mut self) -> LogupColGenerator<'_> {
let log_size = self.log_size;
LogupColGenerator {
gen: self,
numerator: SecureColumnByCoords::<SimdBackend>::zeros(1 << log_size),
}
}
/// Finalize the trace. Returns the trace and the claimed sum of the last column.
pub fn finalize(
mut self,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
SecureField,
) {
// Compute claimed sum.
let mut last_col_coords = self.trace.pop().unwrap().columns;
let packed_sums: [PackedBaseField; SECURE_EXTENSION_DEGREE] = last_col_coords
.each_ref()
.map(|c| c.data.iter().copied().sum());
let base_sums = packed_sums.map(|s| s.pointwise_sum());
let claimed_sum = SecureField::from_m31_array(base_sums);
// Shift the last column to make the sum zero.
let cumsum_shift = claimed_sum / BaseField::from_u32_unchecked(1 << self.log_size);
last_col_coords.iter_mut().enumerate().for_each(|(i, c)| {
c.data
.iter_mut()
.for_each(|x| *x -= PackedBaseField::broadcast(cumsum_shift.to_m31_array()[i]))
});
// Prefix sum the last column.
let coord_prefix_sum = last_col_coords.map(inclusive_prefix_sum);
self.trace.push(SecureColumnByCoords {
columns: coord_prefix_sum,
});
let trace = self
.trace
.into_iter()
.flat_map(|eval| {
eval.columns.map(|c| {
CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new(
CanonicCoset::new(self.log_size).circle_domain(),
c,
)
})
})
.collect_vec();
(trace, claimed_sum)
}
}
/// Trace generator for a single lookup column.
pub struct LogupColGenerator<'a> {
gen: &'a mut LogupTraceGenerator,
/// Numerator expressions (i.e. multiplicities) being generated for the current lookup.
numerator: SecureColumnByCoords<SimdBackend>,
}
impl<'a> LogupColGenerator<'a> {
/// Write a fraction to the column at a row.
pub fn write_frac(
&mut self,
vec_row: usize,
numerator: PackedSecureField,
denom: PackedSecureField,
) {
debug_assert!(
denom.to_array().iter().all(|x| *x != SecureField::zero()),
"{:?}",
("denom at vec_row {} is zero {}", denom, vec_row)
);
unsafe {
self.numerator.set_packed(vec_row, numerator);
*self.gen.denom.data.get_unchecked_mut(vec_row) = denom;
}
}
/// Finalizes generating the column.
pub fn finalize_col(mut self) {
FieldExpOps::batch_inverse(&self.gen.denom.data, &mut self.gen.denom_inv.data);
for vec_row in 0..(1 << (self.gen.log_size - LOG_N_LANES)) {
unsafe {
let value = self.numerator.packed_at(vec_row)
* *self.gen.denom_inv.data.get_unchecked(vec_row);
let prev_value = self
.gen
.trace
.last()
.map(|col| col.packed_at(vec_row))
.unwrap_or_else(PackedSecureField::zero);
self.numerator.set_packed(vec_row, value + prev_value)
};
}
self.gen.trace.push(self.numerator)
}
}
#[cfg(test)]
mod tests {
use num_traits::One;
use super::LogupAtRow;
use crate::constraint_framework::InfoEvaluator;
use crate::core::fields::qm31::SecureField;
#[test]
#[should_panic]
fn test_logup_not_finalized_panic() {
let mut logup = LogupAtRow::<2, InfoEvaluator>::new(1, SecureField::one(), 7);
logup.push_frac(
&mut InfoEvaluator::default(),
SecureField::one(),
SecureField::one(),
);
}
}

View File

@ -0,0 +1,97 @@
/// ! This module contains helpers to express and use constraints for components.
mod assert;
mod component;
pub mod constant_columns;
mod info;
pub mod logup;
mod point;
mod simd_domain;
use std::array;
use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, Neg, Sub};
pub use assert::{assert_constraints, AssertEvaluator};
pub use component::{FrameworkComponent, FrameworkEval, TraceLocationAllocator};
pub use info::InfoEvaluator;
use num_traits::{One, Zero};
pub use point::PointEvaluator;
pub use simd_domain::SimdDomainEvaluator;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::fields::FieldExpOps;
/// A trait for evaluating expressions at some point or row.
pub trait EvalAtRow {
// TODO(spapini): Use a better trait for these, like 'Algebra' or something.
/// The field type holding values of columns for the component. These are the inputs to the
/// constraints. It might be [BaseField] packed types, or even [SecureField], when evaluating
/// the columns out of domain.
type F: FieldExpOps
+ Copy
+ Debug
+ Zero
+ Neg<Output = Self::F>
+ AddAssign
+ AddAssign<BaseField>
+ Add<Self::F, Output = Self::F>
+ Sub<Self::F, Output = Self::F>
+ Mul<BaseField, Output = Self::F>
+ Add<SecureField, Output = Self::EF>
+ Mul<SecureField, Output = Self::EF>
+ Neg<Output = Self::F>
+ From<BaseField>;
/// A field type representing the closure of `F` with multiplying by [SecureField]. Constraints
/// usually get multiplied by [SecureField] values for security.
type EF: One
+ Copy
+ Debug
+ Zero
+ From<Self::F>
+ Neg<Output = Self::EF>
+ AddAssign
+ Add<SecureField, Output = Self::EF>
+ Sub<SecureField, Output = Self::EF>
+ Mul<SecureField, Output = Self::EF>
+ Add<Self::F, Output = Self::EF>
+ Mul<Self::F, Output = Self::EF>
+ Sub<Self::EF, Output = Self::EF>
+ Mul<Self::EF, Output = Self::EF>
+ From<SecureField>
+ From<Self::F>;
/// Returns the next mask value for the first interaction at offset 0.
fn next_trace_mask(&mut self) -> Self::F {
let [mask_item] = self.next_interaction_mask(0, [0]);
mask_item
}
/// Returns the mask values of the given offsets for the next column in the interaction.
fn next_interaction_mask<const N: usize>(
&mut self,
interaction: usize,
offsets: [isize; N],
) -> [Self::F; N];
/// Returns the extension mask values of the given offsets for the next extension degree many
/// columns in the interaction.
fn next_extension_interaction_mask<const N: usize>(
&mut self,
interaction: usize,
offsets: [isize; N],
) -> [Self::EF; N] {
let res_col_major = array::from_fn(|_| self.next_interaction_mask(interaction, offsets));
array::from_fn(|i| Self::combine_ef(res_col_major.map(|c| c[i])))
}
/// Adds a constraint to the component.
fn add_constraint<G>(&mut self, constraint: G)
where
Self::EF: Mul<G, Output = Self::EF>;
/// Combines 4 base field values into a single extension field value.
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF;
}

View File

@ -0,0 +1,57 @@
use std::ops::Mul;
use super::EvalAtRow;
use crate::core::air::accumulation::PointEvaluationAccumulator;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::pcs::TreeVec;
use crate::core::ColumnVec;
/// Evaluates expressions at a point out of domain.
pub struct PointEvaluator<'a> {
pub mask: TreeVec<ColumnVec<&'a Vec<SecureField>>>,
pub evaluation_accumulator: &'a mut PointEvaluationAccumulator,
pub col_index: Vec<usize>,
pub denom_inverse: SecureField,
}
impl<'a> PointEvaluator<'a> {
pub fn new(
mask: TreeVec<ColumnVec<&'a Vec<SecureField>>>,
evaluation_accumulator: &'a mut PointEvaluationAccumulator,
denom_inverse: SecureField,
) -> Self {
let col_index = vec![0; mask.len()];
Self {
mask,
evaluation_accumulator,
col_index,
denom_inverse,
}
}
}
impl<'a> EvalAtRow for PointEvaluator<'a> {
type F = SecureField;
type EF = SecureField;
fn next_interaction_mask<const N: usize>(
&mut self,
interaction: usize,
_offsets: [isize; N],
) -> [Self::F; N] {
let col_index = self.col_index[interaction];
self.col_index[interaction] += 1;
let mask = self.mask[interaction][col_index].clone();
assert_eq!(mask.len(), N);
mask.try_into().unwrap()
}
fn add_constraint<G>(&mut self, constraint: G)
where
Self::EF: Mul<G, Output = Self::EF>,
{
self.evaluation_accumulator
.accumulate(self.denom_inverse * constraint);
}
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
SecureField::from_partial_evals(values)
}
}

View File

@ -0,0 +1,106 @@
use std::ops::Mul;
use num_traits::Zero;
use super::EvalAtRow;
use crate::core::backend::simd::column::VeryPackedBaseColumn;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::very_packed_m31::{
VeryPackedBaseField, VeryPackedSecureField, LOG_N_VERY_PACKED_ELEMS,
};
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::Column;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::CircleEvaluation;
use crate::core::poly::BitReversedOrder;
use crate::core::utils::offset_bit_reversed_circle_domain_index;
/// Evaluates constraints at an evaluation domain points.
pub struct SimdDomainEvaluator<'a> {
pub trace_eval:
&'a TreeVec<Vec<&'a CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>>,
pub column_index_per_interaction: Vec<usize>,
/// The row index of the simd-vector row to evaluate the constraints at.
pub vec_row: usize,
pub random_coeff_powers: &'a [SecureField],
pub row_res: VeryPackedSecureField,
pub constraint_index: usize,
pub domain_log_size: u32,
pub eval_domain_log_size: u32,
}
impl<'a> SimdDomainEvaluator<'a> {
pub fn new(
trace_eval: &'a TreeVec<Vec<&CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>>,
vec_row: usize,
random_coeff_powers: &'a [SecureField],
domain_log_size: u32,
eval_log_size: u32,
) -> Self {
Self {
trace_eval,
column_index_per_interaction: vec![0; trace_eval.len()],
vec_row,
random_coeff_powers,
row_res: VeryPackedSecureField::zero(),
constraint_index: 0,
domain_log_size,
eval_domain_log_size: eval_log_size,
}
}
}
impl<'a> EvalAtRow for SimdDomainEvaluator<'a> {
type F = VeryPackedBaseField;
type EF = VeryPackedSecureField;
// TODO(spapini): Remove all boundary checks.
fn next_interaction_mask<const N: usize>(
&mut self,
interaction: usize,
offsets: [isize; N],
) -> [Self::F; N] {
let col_index = self.column_index_per_interaction[interaction];
self.column_index_per_interaction[interaction] += 1;
offsets.map(|off| {
// If the offset is 0, we can just return the value directly from this row.
if off == 0 {
unsafe {
let col = &self
.trace_eval
.get_unchecked(interaction)
.get_unchecked(col_index)
.values;
let very_packed_col = VeryPackedBaseColumn::transform_under_ref(col);
return *very_packed_col.data.get_unchecked(self.vec_row);
};
}
// Otherwise, we need to look up the value at the offset.
// Since the domain is bit-reversed circle domain ordered, we need to look up the value
// at the bit-reversed natural order index at an offset.
VeryPackedBaseField::from_array(std::array::from_fn(|i| {
let row_index = offset_bit_reversed_circle_domain_index(
(self.vec_row << (LOG_N_LANES + LOG_N_VERY_PACKED_ELEMS)) + i,
self.domain_log_size,
self.eval_domain_log_size,
off,
);
self.trace_eval[interaction][col_index].at(row_index)
}))
})
}
fn add_constraint<G>(&mut self, constraint: G)
where
Self::EF: Mul<G, Output = Self::EF>,
{
self.row_res +=
VeryPackedSecureField::broadcast(self.random_coeff_powers[self.constraint_index])
* constraint;
self.constraint_index += 1;
}
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
VeryPackedSecureField::from_very_packed_m31s(values)
}
}

View File

@ -0,0 +1,297 @@
//! Accumulators for a random linear combination of circle polynomials.
//! Given N polynomials, u_0(P), ... u_{N-1}(P), and a random alpha, the combined polynomial is
//! defined as
//! f(p) = sum_i alpha^{N-1-i} u_i(P).
use itertools::Itertools;
use tracing::{span, Level};
use crate::core::backend::{Backend, Col, Column, CpuBackend};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fields::FieldOps;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, SecureCirclePoly};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::generate_secure_powers;
/// Accumulates N evaluations of u_i(P0) at a single point.
/// Computes f(P0), the combined polynomial at that point.
/// For n accumulated evaluations, the i'th evaluation is multiplied by alpha^(N-1-i).
pub struct PointEvaluationAccumulator {
random_coeff: SecureField,
accumulation: SecureField,
}
impl PointEvaluationAccumulator {
/// Creates a new accumulator.
/// `random_coeff` should be a secure random field element, drawn from the channel.
pub fn new(random_coeff: SecureField) -> Self {
Self {
random_coeff,
accumulation: SecureField::default(),
}
}
/// Accumulates u_i(P0), a polynomial evaluation at a P0 in reverse order.
pub fn accumulate(&mut self, evaluation: SecureField) {
self.accumulation = self.accumulation * self.random_coeff + evaluation;
}
pub fn finalize(self) -> SecureField {
self.accumulation
}
}
// TODO(ShaharS), rename terminology to constraints instead of columns.
/// Accumulates evaluations of u_i(P), each at an evaluation domain of the size of that polynomial.
/// Computes the coefficients of f(P).
pub struct DomainEvaluationAccumulator<B: Backend> {
random_coeff_powers: Vec<SecureField>,
/// Accumulated evaluations for each log_size.
/// Each `sub_accumulation` holds the sum over all columns i of that log_size, of
/// `evaluation_i * alpha^(N - 1 - i)`
/// where `N` is the total number of evaluations.
sub_accumulations: Vec<Option<SecureColumnByCoords<B>>>,
}
impl<B: Backend> DomainEvaluationAccumulator<B> {
/// Creates a new accumulator.
/// `random_coeff` should be a secure random field element, drawn from the channel.
/// `max_log_size` is the maximum log_size of the accumulated evaluations.
pub fn new(random_coeff: SecureField, max_log_size: u32, total_columns: usize) -> Self {
let max_log_size = max_log_size as usize;
Self {
random_coeff_powers: generate_secure_powers(random_coeff, total_columns),
sub_accumulations: (0..(max_log_size + 1)).map(|_| None).collect(),
}
}
/// Gets accumulators for some sizes.
/// `n_cols_per_size` is an array of pairs (log_size, n_cols).
/// For each entry, a [ColumnAccumulator] is returned, expecting to accumulate `n_cols`
/// evaluations of size `log_size`.
/// The array size, `N`, is the number of different sizes.
pub fn columns<const N: usize>(
&mut self,
n_cols_per_size: [(u32, usize); N],
) -> [ColumnAccumulator<'_, B>; N] {
self.sub_accumulations
.get_many_mut(n_cols_per_size.map(|(log_size, _)| log_size as usize))
.unwrap_or_else(|e| panic!("invalid log_sizes: {}", e))
.into_iter()
.zip(n_cols_per_size)
.map(|(col, (log_size, n_cols))| {
let random_coeffs = self
.random_coeff_powers
.split_off(self.random_coeff_powers.len() - n_cols);
ColumnAccumulator {
random_coeff_powers: random_coeffs,
col: col.get_or_insert_with(|| SecureColumnByCoords::zeros(1 << log_size)),
}
})
.collect_vec()
.try_into()
.unwrap_or_else(|_| unreachable!())
}
/// Returns the log size of the resulting polynomial.
pub fn log_size(&self) -> u32 {
(self.sub_accumulations.len() - 1) as u32
}
}
pub trait AccumulationOps: FieldOps<BaseField> + Sized {
/// Accumulates other into column:
/// column = column + other.
fn accumulate(column: &mut SecureColumnByCoords<Self>, other: &SecureColumnByCoords<Self>);
}
impl<B: Backend> DomainEvaluationAccumulator<B> {
/// Computes f(P) as coefficients.
pub fn finalize(self) -> SecureCirclePoly<B> {
assert_eq!(
self.random_coeff_powers.len(),
0,
"not all random coefficients were used"
);
let log_size = self.log_size();
let _span = span!(Level::INFO, "Constraints interpolation").entered();
let mut cur_poly: Option<SecureCirclePoly<B>> = None;
let twiddles = B::precompute_twiddles(
CanonicCoset::new(self.log_size())
.circle_domain()
.half_coset,
);
for (log_size, values) in self.sub_accumulations.into_iter().enumerate().skip(1) {
let Some(mut values) = values else {
continue;
};
if let Some(prev_poly) = cur_poly {
let eval = SecureColumnByCoords {
columns: prev_poly.0.map(|c| {
c.evaluate_with_twiddles(
CanonicCoset::new(log_size as u32).circle_domain(),
&twiddles,
)
.values
}),
};
B::accumulate(&mut values, &eval);
}
cur_poly = Some(SecureCirclePoly(values.columns.map(|c| {
CircleEvaluation::<B, BaseField, BitReversedOrder>::new(
CanonicCoset::new(log_size as u32).circle_domain(),
c,
)
.interpolate_with_twiddles(&twiddles)
})));
}
cur_poly.unwrap_or_else(|| {
SecureCirclePoly(std::array::from_fn(|_| {
CirclePoly::new(Col::<B, BaseField>::zeros(1 << log_size))
}))
})
}
}
/// A domain accumulator for polynomials of a single size.
pub struct ColumnAccumulator<'a, B: Backend> {
pub random_coeff_powers: Vec<SecureField>,
pub col: &'a mut SecureColumnByCoords<B>,
}
impl<'a> ColumnAccumulator<'a, CpuBackend> {
pub fn accumulate(&mut self, index: usize, evaluation: SecureField) {
let val = self.col.at(index) + evaluation;
self.col.set(index, val);
}
}
#[cfg(test)]
mod tests {
use std::array;
use num_traits::Zero;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use super::*;
use crate::core::backend::cpu::CpuCircleEvaluation;
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::{M31, P};
use crate::qm31;
#[test]
fn test_point_evaluation_accumulator() {
// Generate a vector of random sizes with a constant seed.
let mut rng = SmallRng::seed_from_u64(0);
const MAX_LOG_SIZE: u32 = 10;
const MASK: u32 = P;
let log_sizes = (0..100)
.map(|_| rng.gen_range(4..MAX_LOG_SIZE))
.collect::<Vec<_>>();
// Generate random evaluations.
let evaluations = log_sizes
.iter()
.map(|_| M31::from_u32_unchecked(rng.gen::<u32>() & MASK))
.collect::<Vec<_>>();
let alpha = qm31!(2, 3, 4, 5);
// Use accumulator.
let mut accumulator = PointEvaluationAccumulator::new(alpha);
for (_, evaluation) in log_sizes.iter().zip(evaluations.iter()) {
accumulator.accumulate((*evaluation).into());
}
let accumulator_res = accumulator.finalize();
// Use direct computation.
let mut res = SecureField::default();
for evaluation in evaluations.iter() {
res = res * alpha + *evaluation;
}
assert_eq!(accumulator_res, res);
}
#[test]
fn test_domain_evaluation_accumulator() {
// Generate a vector of random sizes with a constant seed.
let mut rng = SmallRng::seed_from_u64(0);
const LOG_SIZE_MIN: u32 = 4;
const LOG_SIZE_BOUND: u32 = 10;
const MASK: u32 = P;
let mut log_sizes = (0..100)
.map(|_| rng.gen_range(LOG_SIZE_MIN..LOG_SIZE_BOUND))
.collect::<Vec<_>>();
log_sizes.sort();
// Generate random evaluations.
let evaluations = log_sizes
.iter()
.map(|log_size| {
(0..(1 << *log_size))
.map(|_| M31::from_u32_unchecked(rng.gen::<u32>() & MASK))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
let alpha = qm31!(2, 3, 4, 5);
// Use accumulator.
let mut accumulator = DomainEvaluationAccumulator::<CpuBackend>::new(
alpha,
LOG_SIZE_BOUND,
evaluations.len(),
);
let n_cols_per_size: [(u32, usize); (LOG_SIZE_BOUND - LOG_SIZE_MIN) as usize] =
array::from_fn(|i| {
let current_log_size = LOG_SIZE_MIN + i as u32;
let n_cols = log_sizes
.iter()
.copied()
.filter(|&log_size| log_size == current_log_size)
.count();
(current_log_size, n_cols)
});
let mut cols = accumulator.columns(n_cols_per_size);
let mut eval_chunk_offset = 0;
for (log_size, n_cols) in n_cols_per_size.iter() {
for index in 0..(1 << log_size) {
let mut val = SecureField::zero();
for (eval_index, (col_log_size, evaluation)) in
log_sizes.iter().zip(evaluations.iter()).enumerate()
{
if *log_size != *col_log_size {
continue;
}
// The random coefficient powers chunk is in regular order.
let random_coeff_chunk =
&cols[(log_size - LOG_SIZE_MIN) as usize].random_coeff_powers;
val += random_coeff_chunk
[random_coeff_chunk.len() - 1 - (eval_index - eval_chunk_offset)]
* evaluation[index];
}
cols[(log_size - LOG_SIZE_MIN) as usize].accumulate(index, val);
}
eval_chunk_offset += n_cols;
}
let accumulator_poly = accumulator.finalize();
// Pick an arbitrary sample point.
let point = CirclePoint::<SecureField>::get_point(98989892);
let accumulator_res = accumulator_poly.eval_at_point(point);
// Use direct computation.
let mut res = SecureField::default();
for (log_size, values) in log_sizes.into_iter().zip(evaluations) {
res = res * alpha
+ CpuCircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), values)
.interpolate()
.eval_at_point(point);
}
assert_eq!(accumulator_res, res);
}
}

View File

@ -0,0 +1,80 @@
use itertools::Itertools;
use super::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use super::{Component, ComponentProver, Trace};
use crate::core::backend::Backend;
use crate::core::circle::CirclePoint;
use crate::core::fields::qm31::SecureField;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::SecureCirclePoly;
use crate::core::ColumnVec;
pub struct Components<'a>(pub Vec<&'a dyn Component>);
impl<'a> Components<'a> {
pub fn composition_log_degree_bound(&self) -> u32 {
self.0
.iter()
.map(|component| component.max_constraint_log_degree_bound())
.max()
.unwrap()
}
pub fn mask_points(
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
TreeVec::concat_cols(self.0.iter().map(|component| component.mask_points(point)))
}
pub fn eval_composition_polynomial_at_point(
&self,
point: CirclePoint<SecureField>,
mask_values: &TreeVec<Vec<Vec<SecureField>>>,
random_coeff: SecureField,
) -> SecureField {
//accumulator for the random linear comination over powers of random_coeff
let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff);
for component in &self.0 {
component.evaluate_constraint_quotients_at_point(
point,
mask_values,
&mut evaluation_accumulator,
)
}
evaluation_accumulator.finalize()
}
pub fn column_log_sizes(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::concat_cols(
self.0
.iter()
.map(|component| component.trace_log_degree_bounds()),
)
}
}
pub struct ComponentProvers<'a, B: Backend>(pub Vec<&'a dyn ComponentProver<B>>);
impl<'a, B: Backend> ComponentProvers<'a, B> {
pub fn components(&self) -> Components<'_> {
Components(self.0.iter().map(|c| *c as &dyn Component).collect_vec())
}
pub fn compute_composition_polynomial(
&self,
random_coeff: SecureField,
trace: &Trace<'_, B>,
) -> SecureCirclePoly<B> {
let total_constraints: usize = self.0.iter().map(|c| c.n_constraints()).sum();
let mut accumulator = DomainEvaluationAccumulator::new(
random_coeff,
self.components().composition_log_degree_bound(),
total_constraints,
);
for component in &self.0 {
component.evaluate_constraint_quotients_on_domain(trace, &mut accumulator)
}
accumulator.finalize()
}
}

View File

@ -0,0 +1,91 @@
use std::collections::HashSet;
use std::vec;
use itertools::Itertools;
use crate::core::circle::CirclePoint;
use crate::core::fields::qm31::SecureField;
use crate::core::poly::circle::CanonicCoset;
use crate::core::ColumnVec;
/// Mask holds a vector with an entry for each column.
/// Each entry holds a list of mask items, which are the offsets of the mask at that column.
type Mask = ColumnVec<Vec<usize>>;
/// Returns the same point for each mask item.
/// Should be used where all the mask items has no shift from the constraint point.
pub fn fixed_mask_points(
mask: &Mask,
point: CirclePoint<SecureField>,
) -> ColumnVec<Vec<CirclePoint<SecureField>>> {
assert_eq!(
mask.iter()
.flat_map(|mask_entry| mask_entry.iter().collect::<HashSet<_>>())
.collect::<HashSet<&usize>>()
.into_iter()
.collect_vec(),
vec![&0]
);
mask.iter()
.map(|mask_entry| mask_entry.iter().map(|_| point).collect())
.collect()
}
/// For each mask item returns the point shifted by the domain initial point of the column.
/// Should be used where the mask items are shifted from the constraint point.
pub fn shifted_mask_points(
mask: &Mask,
domains: &[CanonicCoset],
point: CirclePoint<SecureField>,
) -> ColumnVec<Vec<CirclePoint<SecureField>>> {
mask.iter()
.zip(domains.iter())
.map(|(mask_entry, domain)| {
mask_entry
.iter()
.map(|mask_item| point + domain.at(*mask_item).into_ef())
.collect()
})
.collect()
}
#[cfg(test)]
mod tests {
use crate::core::air::mask::{fixed_mask_points, shifted_mask_points};
use crate::core::circle::CirclePoint;
use crate::core::poly::circle::CanonicCoset;
#[test]
fn test_mask_fixed_points() {
let mask = vec![vec![0], vec![0]];
let constraint_point = CirclePoint::get_point(1234);
let points = fixed_mask_points(&mask, constraint_point);
assert_eq!(points.len(), 2);
assert_eq!(points[0].len(), 1);
assert_eq!(points[1].len(), 1);
assert_eq!(points[0][0], constraint_point);
assert_eq!(points[1][0], constraint_point);
}
#[test]
fn test_mask_shifted_points() {
let mask = vec![vec![0, 1], vec![0, 1, 2]];
let constraint_point = CirclePoint::get_point(1234);
let domains = (0..mask.len() as u32)
.map(|i| CanonicCoset::new(7 + i))
.collect::<Vec<_>>();
let points = shifted_mask_points(&mask, &domains, constraint_point);
assert_eq!(points.len(), 2);
assert_eq!(points[0].len(), 2);
assert_eq!(points[1].len(), 3);
assert_eq!(points[0][0], constraint_point + domains[0].at(0).into_ef());
assert_eq!(points[0][1], constraint_point + domains[0].at(1).into_ef());
assert_eq!(points[1][0], constraint_point + domains[1].at(0).into_ef());
assert_eq!(points[1][1], constraint_point + domains[1].at(1).into_ef());
assert_eq!(points[1][2], constraint_point + domains[1].at(2).into_ef());
}
}

View File

@ -0,0 +1,76 @@
pub use components::{ComponentProvers, Components};
use self::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use super::backend::Backend;
use super::circle::CirclePoint;
use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::pcs::TreeVec;
use super::poly::circle::{CircleEvaluation, CirclePoly};
use super::poly::BitReversedOrder;
use super::ColumnVec;
pub mod accumulation;
mod components;
pub mod mask;
/// Arithmetic Intermediate Representation (AIR).
/// An Air instance is assumed to already contain all the information needed to
/// evaluate the constraints.
/// For instance, all interaction elements are assumed to be present in it.
/// Therefore, an AIR is generated only after the initial trace commitment phase.
// TODO(spapini): consider renaming this struct.
pub trait Air {
fn components(&self) -> Vec<&dyn Component>;
}
pub trait AirProver<B: Backend>: Air {
fn component_provers(&self) -> Vec<&dyn ComponentProver<B>>;
}
/// A component is a set of trace columns of various sizes along with a set of
/// constraints on them.
pub trait Component {
fn n_constraints(&self) -> usize;
fn max_constraint_log_degree_bound(&self) -> u32;
/// Returns the degree bounds of each trace column. The returned TreeVec should be of size
/// `n_interaction_phases`.
fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>>;
/// Returns the mask points for each trace column. The returned TreeVec should be of size
/// `n_interaction_phases`.
fn mask_points(
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>>;
/// Evaluates the constraint quotients combination of the component at a point.
fn evaluate_constraint_quotients_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &TreeVec<ColumnVec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
);
}
pub trait ComponentProver<B: Backend>: Component {
/// Evaluates the constraint quotients of the component on the evaluation domain.
/// Accumulates quotients in `evaluation_accumulator`.
fn evaluate_constraint_quotients_on_domain(
&self,
trace: &Trace<'_, B>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<B>,
);
}
/// The set of polynomials that make up the trace.
///
/// Each polynomial is stored both in a coefficients, and evaluations form (for efficiency)
pub struct Trace<'a, B: Backend> {
/// Polynomials for each column.
pub polys: TreeVec<ColumnVec<&'a CirclePoly<B>>>,
/// Evaluations for each column (evaluated on their commitment domains).
pub evals: TreeVec<ColumnVec<&'a CircleEvaluation<B, BaseField, BitReversedOrder>>>,
}

View File

@ -0,0 +1,12 @@
use super::CpuBackend;
use crate::core::air::accumulation::AccumulationOps;
use crate::core::fields::secure_column::SecureColumnByCoords;
impl AccumulationOps for CpuBackend {
fn accumulate(column: &mut SecureColumnByCoords<Self>, other: &SecureColumnByCoords<Self>) {
for i in 0..column.len() {
let res_coeff = column.at(i) + other.at(i);
column.set(i, res_coeff);
}
}
}

View File

@ -0,0 +1,24 @@
use itertools::Itertools;
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::vcs::blake2_hash::Blake2sHash;
use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};
impl MerkleOps<Blake2sMerkleHasher> for CpuBackend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<Blake2sHash>>,
columns: &[&Vec<BaseField>],
) -> Vec<Blake2sHash> {
(0..(1 << log_size))
.map(|i| {
Blake2sMerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column[i]).collect_vec(),
)
})
.collect()
}
}

View File

@ -0,0 +1,376 @@
use num_traits::Zero;
use super::CpuBackend;
use crate::core::backend::{Col, ColumnOps};
use crate::core::circle::{CirclePoint, Coset};
use crate::core::fft::{butterfly, ibutterfly};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{ExtensionOf, FieldExpOps};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
};
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order};
impl PolyOps for CpuBackend {
type Twiddles = Vec<BaseField>;
fn new_canonical_ordered(
coset: CanonicCoset,
values: Col<Self, BaseField>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
let domain = coset.circle_domain();
assert_eq!(values.len(), domain.size());
let mut new_values = coset_order_to_circle_domain_order(&values);
CpuBackend::bit_reverse_column(&mut new_values);
CircleEvaluation::new(domain, new_values)
}
fn interpolate(
eval: CircleEvaluation<Self, BaseField, BitReversedOrder>,
twiddles: &TwiddleTree<Self>,
) -> CirclePoly<Self> {
assert!(eval.domain.half_coset.is_doubling_of(twiddles.root_coset));
let mut values = eval.values;
if eval.domain.log_size() == 1 {
let y = eval.domain.half_coset.initial.y;
let n = BaseField::from(2);
let yn_inv = (y * n).inverse();
let y_inv = yn_inv * n;
let n_inv = yn_inv * y;
let (mut v0, mut v1) = (values[0], values[1]);
ibutterfly(&mut v0, &mut v1, y_inv);
return CirclePoly::new(vec![v0 * n_inv, v1 * n_inv]);
}
if eval.domain.log_size() == 2 {
let CirclePoint { x, y } = eval.domain.half_coset.initial;
let n = BaseField::from(4);
let xyn_inv = (x * y * n).inverse();
let x_inv = xyn_inv * y * n;
let y_inv = xyn_inv * x * n;
let n_inv = xyn_inv * x * y;
let (mut v0, mut v1, mut v2, mut v3) = (values[0], values[1], values[2], values[3]);
ibutterfly(&mut v0, &mut v1, y_inv);
ibutterfly(&mut v2, &mut v3, -y_inv);
ibutterfly(&mut v0, &mut v2, x_inv);
ibutterfly(&mut v1, &mut v3, x_inv);
return CirclePoly::new(vec![v0 * n_inv, v1 * n_inv, v2 * n_inv, v3 * n_inv]);
}
let line_twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles);
let circle_twiddles = circle_twiddles_from_line_twiddles(line_twiddles[0]);
for (h, t) in circle_twiddles.enumerate() {
fft_layer_loop(&mut values, 0, h, t, ibutterfly);
}
for (layer, layer_twiddles) in line_twiddles.into_iter().enumerate() {
for (h, &t) in layer_twiddles.iter().enumerate() {
fft_layer_loop(&mut values, layer + 1, h, t, ibutterfly);
}
}
// Divide all values by 2^log_size.
let inv = BaseField::from_u32_unchecked(eval.domain.size() as u32).inverse();
for val in &mut values {
*val *= inv;
}
CirclePoly::new(values)
}
fn eval_at_point(poly: &CirclePoly<Self>, point: CirclePoint<SecureField>) -> SecureField {
if poly.log_size() == 0 {
return poly.coeffs[0].into();
}
let mut mappings = vec![point.y];
let mut x = point.x;
for _ in 1..poly.log_size() {
mappings.push(x);
x = CirclePoint::double_x(x);
}
mappings.reverse();
fold(&poly.coeffs, &mappings)
}
fn extend(poly: &CirclePoly<Self>, log_size: u32) -> CirclePoly<Self> {
assert!(log_size >= poly.log_size());
let mut coeffs = Vec::with_capacity(1 << log_size);
coeffs.extend_from_slice(&poly.coeffs);
coeffs.resize(1 << log_size, BaseField::zero());
CirclePoly::new(coeffs)
}
fn evaluate(
poly: &CirclePoly<Self>,
domain: CircleDomain,
twiddles: &TwiddleTree<Self>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
assert!(domain.half_coset.is_doubling_of(twiddles.root_coset));
let mut values = poly.extend(domain.log_size()).coeffs;
if domain.log_size() == 1 {
let (mut v0, mut v1) = (values[0], values[1]);
butterfly(&mut v0, &mut v1, domain.half_coset.initial.y);
return CircleEvaluation::new(domain, vec![v0, v1]);
}
if domain.log_size() == 2 {
let (mut v0, mut v1, mut v2, mut v3) = (values[0], values[1], values[2], values[3]);
let CirclePoint { x, y } = domain.half_coset.initial;
butterfly(&mut v0, &mut v2, x);
butterfly(&mut v1, &mut v3, x);
butterfly(&mut v0, &mut v1, y);
butterfly(&mut v2, &mut v3, -y);
return CircleEvaluation::new(domain, vec![v0, v1, v2, v3]);
}
let line_twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles);
let circle_twiddles = circle_twiddles_from_line_twiddles(line_twiddles[0]);
for (layer, layer_twiddles) in line_twiddles.iter().enumerate().rev() {
for (h, &t) in layer_twiddles.iter().enumerate() {
fft_layer_loop(&mut values, layer + 1, h, t, butterfly);
}
}
for (h, t) in circle_twiddles.enumerate() {
fft_layer_loop(&mut values, 0, h, t, butterfly);
}
CircleEvaluation::new(domain, values)
}
fn precompute_twiddles(mut coset: Coset) -> TwiddleTree<Self> {
const CHUNK_LOG_SIZE: usize = 12;
const CHUNK_SIZE: usize = 1 << CHUNK_LOG_SIZE;
let root_coset = coset;
let mut twiddles = Vec::with_capacity(coset.size());
for _ in 0..coset.log_size() {
let i0 = twiddles.len();
twiddles.extend(
coset
.iter()
.take(coset.size() / 2)
.map(|p| p.x)
.collect::<Vec<_>>(),
);
bit_reverse(&mut twiddles[i0..]);
coset = coset.double();
}
twiddles.push(1.into());
// Inverse twiddles.
// Fallback to the non-chunked version if the domain is not big enough.
if CHUNK_SIZE > coset.size() {
let itwiddles = twiddles.iter().map(|&t| t.inverse()).collect();
return TwiddleTree {
root_coset,
twiddles,
itwiddles,
};
}
let mut itwiddles = vec![BaseField::zero(); twiddles.len()];
twiddles
.array_chunks::<CHUNK_SIZE>()
.zip(itwiddles.array_chunks_mut::<CHUNK_SIZE>())
.for_each(|(src, dst)| {
BaseField::batch_inverse(src, dst);
});
TwiddleTree {
root_coset,
twiddles,
itwiddles,
}
}
}
fn fft_layer_loop(
values: &mut [BaseField],
i: usize,
h: usize,
t: BaseField,
butterfly_fn: impl Fn(&mut BaseField, &mut BaseField, BaseField),
) {
for l in 0..(1 << i) {
let idx0 = (h << (i + 1)) + l;
let idx1 = idx0 + (1 << i);
let (mut val0, mut val1) = (values[idx0], values[idx1]);
butterfly_fn(&mut val0, &mut val1, t);
(values[idx0], values[idx1]) = (val0, val1);
}
}
/// Computes the circle twiddles layer (layer 0) from the first line twiddles layer (layer 1).
///
/// Only works for line twiddles generated from a domain with size `>4`.
fn circle_twiddles_from_line_twiddles(
first_line_twiddles: &[BaseField],
) -> impl Iterator<Item = BaseField> + '_ {
// The twiddles for layer 0 can be computed from the twiddles for layer 1.
// Since the twiddles are bit reversed, we consider the circle domain in bit reversed order.
// Each consecutive 4 points in the bit reversed order of a coset form a circle coset of size 4.
// A circle coset of size 4 in bit reversed order looks like this:
// [(x, y), (-x, -y), (y, -x), (-y, x)]
// Note: This relation is derived from the fact that `M31_CIRCLE_GEN`.repeated_double(ORDER / 4)
// == (-1,0), and not (0,1). (0,1) would yield another relation.
// The twiddles for layer 0 are the y coordinates:
// [y, -y, -x, x]
// The twiddles for layer 1 in bit reversed order are the x coordinates of the even indices
// points:
// [x, y]
// Works also for inverse of the twiddles.
first_line_twiddles
.iter()
.array_chunks()
.flat_map(|[&x, &y]| [y, -y, -x, x])
}
impl<F: ExtensionOf<BaseField>, EvalOrder> IntoIterator
for CircleEvaluation<CpuBackend, F, EvalOrder>
{
type Item = F;
type IntoIter = std::vec::IntoIter<F>;
/// Creates a consuming iterator over the evaluations.
///
/// Evaluations are returned in the same order as elements of the domain.
fn into_iter(self) -> Self::IntoIter {
self.values.into_iter()
}
}
#[cfg(test)]
mod tests {
use std::iter::zip;
use num_traits::One;
use crate::core::backend::cpu::CpuCirclePoly;
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::poly::circle::CanonicCoset;
#[test]
fn test_eval_at_point_with_4_coeffs() {
// Represents the polynomial `1 + 2y + 3x + 4xy`.
// Note coefficients are passed in bit reversed order.
let poly = CpuCirclePoly::new([1, 3, 2, 4].map(BaseField::from).to_vec());
let x = BaseField::from(5).into();
let y = BaseField::from(8).into();
let eval = poly.eval_at_point(CirclePoint { x, y });
assert_eq!(
eval,
poly.coeffs[0] + poly.coeffs[1] * y + poly.coeffs[2] * x + poly.coeffs[3] * x * y
);
}
#[test]
fn test_eval_at_point_with_2_coeffs() {
// Represents the polynomial `1 + 2y`.
let poly = CpuCirclePoly::new(vec![BaseField::from(1), BaseField::from(2)]);
let x = BaseField::from(5).into();
let y = BaseField::from(8).into();
let eval = poly.eval_at_point(CirclePoint { x, y });
assert_eq!(eval, poly.coeffs[0] + poly.coeffs[1] * y);
}
#[test]
fn test_eval_at_point_with_1_coeff() {
// Represents the polynomial `1`.
let poly = CpuCirclePoly::new(vec![BaseField::one()]);
let x = BaseField::from(5).into();
let y = BaseField::from(8).into();
let eval = poly.eval_at_point(CirclePoint { x, y });
assert_eq!(eval, SecureField::one());
}
#[test]
fn test_evaluate_2_coeffs() {
let domain = CanonicCoset::new(1).circle_domain();
let poly = CpuCirclePoly::new((1..=2).map(BaseField::from).collect());
let evaluation = poly.clone().evaluate(domain).bit_reverse();
for (i, (p, eval)) in zip(domain, evaluation).enumerate() {
let eval: SecureField = eval.into();
assert_eq!(eval, poly.eval_at_point(p.into_ef()), "mismatch at i={i}");
}
}
#[test]
fn test_evaluate_4_coeffs() {
let domain = CanonicCoset::new(2).circle_domain();
let poly = CpuCirclePoly::new((1..=4).map(BaseField::from).collect());
let evaluation = poly.clone().evaluate(domain).bit_reverse();
for (i, (x, eval)) in zip(domain, evaluation).enumerate() {
let eval: SecureField = eval.into();
assert_eq!(eval, poly.eval_at_point(x.into_ef()), "mismatch at i={i}");
}
}
#[test]
fn test_evaluate_8_coeffs() {
let domain = CanonicCoset::new(3).circle_domain();
let poly = CpuCirclePoly::new((1..=8).map(BaseField::from).collect());
let evaluation = poly.clone().evaluate(domain).bit_reverse();
for (i, (x, eval)) in zip(domain, evaluation).enumerate() {
let eval: SecureField = eval.into();
assert_eq!(eval, poly.eval_at_point(x.into_ef()), "mismatch at i={i}");
}
}
#[test]
fn test_interpolate_2_evals() {
let poly = CpuCirclePoly::new(vec![BaseField::one(), BaseField::from(2)]);
let domain = CanonicCoset::new(1).circle_domain();
let evals = poly.clone().evaluate(domain);
let interpolated_poly = evals.interpolate();
assert_eq!(interpolated_poly.coeffs, poly.coeffs);
}
#[test]
fn test_interpolate_4_evals() {
let poly = CpuCirclePoly::new((1..=4).map(BaseField::from).collect());
let domain = CanonicCoset::new(2).circle_domain();
let evals = poly.clone().evaluate(domain);
let interpolated_poly = evals.interpolate();
assert_eq!(interpolated_poly.coeffs, poly.coeffs);
}
#[test]
fn test_interpolate_8_evals() {
let poly = CpuCirclePoly::new((1..=8).map(BaseField::from).collect());
let domain = CanonicCoset::new(3).circle_domain();
let evals = poly.clone().evaluate(domain);
let interpolated_poly = evals.interpolate();
assert_eq!(interpolated_poly.coeffs, poly.coeffs);
}
}

View File

@ -0,0 +1,144 @@
use super::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fri::{fold_circle_into_line, fold_line, FriOps};
use crate::core::poly::circle::SecureEvaluation;
use crate::core::poly::line::LineEvaluation;
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::BitReversedOrder;
// TODO(spapini): Optimized these functions as well.
impl FriOps for CpuBackend {
fn fold_line(
eval: &LineEvaluation<Self>,
alpha: SecureField,
_twiddles: &TwiddleTree<Self>,
) -> LineEvaluation<Self> {
fold_line(eval, alpha)
}
fn fold_circle_into_line(
dst: &mut LineEvaluation<Self>,
src: &SecureEvaluation<Self, BitReversedOrder>,
alpha: SecureField,
_twiddles: &TwiddleTree<Self>,
) {
fold_circle_into_line(dst, src, alpha)
}
fn decompose(
eval: &SecureEvaluation<Self, BitReversedOrder>,
) -> (SecureEvaluation<Self, BitReversedOrder>, SecureField) {
let lambda = Self::decomposition_coefficient(eval);
let mut g_values = unsafe { SecureColumnByCoords::<Self>::uninitialized(eval.len()) };
let domain_size = eval.len();
let half_domain_size = domain_size / 2;
for i in 0..half_domain_size {
let x = eval.values.at(i);
let val = x - lambda;
g_values.set(i, val);
}
for i in half_domain_size..domain_size {
let x = eval.values.at(i);
let val = x + lambda;
g_values.set(i, val);
}
let g = SecureEvaluation::new(eval.domain, g_values);
(g, lambda)
}
}
impl CpuBackend {
/// Used to decompose a general polynomial to a polynomial inside the fft-space, and
/// the remainder terms.
/// A coset-diff on a [`CirclePoly`] that is in the FFT space will return zero.
///
/// Let N be the domain size, Let h be a coset size N/2. Using lemma #7 from the CircleStark
/// paper, <f,V_h> = lambda<V_h,V_h> = lambda\*N => lambda = f(0)\*V_h(0) + f(1)*V_h(1) + .. +
/// f(N-1)\*V_h(N-1). The Vanishing polynomial of a cannonic coset sized half the circle
/// domain,evaluated on the circle domain, is [(1, -1, -1, 1)] repeating. This becomes
/// alternating [+-1] in our NaturalOrder, and [(+, +, +, ... , -, -)] in bit reverse.
/// Explicitly, lambda\*N = sum(+f(0..N/2)) + sum(-f(N/2..)).
///
/// # Warning
/// This function assumes the blowupfactor is 2
///
/// [`CirclePoly`]: crate::core::poly::circle::CirclePoly
fn decomposition_coefficient(eval: &SecureEvaluation<Self, BitReversedOrder>) -> SecureField {
let domain_size = 1 << eval.domain.log_size();
let half_domain_size = domain_size / 2;
// eval is in bit-reverse, hence all the positive factors are in the first half, opposite to
// the latter.
let a_sum = (0..half_domain_size)
.map(|i| eval.values.at(i))
.sum::<SecureField>();
let b_sum = (half_domain_size..domain_size)
.map(|i| eval.values.at(i))
.sum::<SecureField>();
// lambda = sum(+-f(p)) / 2N.
(a_sum - b_sum) / BaseField::from_u32_unchecked(domain_size as u32)
}
}
#[cfg(test)]
mod tests {
use num_traits::Zero;
use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly};
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fri::FriOps;
use crate::core::poly::circle::{CanonicCoset, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::m31;
#[test]
fn decompose_coeff_out_fft_space_test() {
for domain_log_size in 5..12 {
let domain_log_half_size = domain_log_size - 1;
let s = CanonicCoset::new(domain_log_size);
let domain = s.circle_domain();
let mut coeffs = vec![BaseField::zero(); 1 << domain_log_size];
// Polynomial is out of FFT space.
coeffs[1 << domain_log_half_size] = m31!(1);
assert!(!CpuCirclePoly::new(coeffs.clone()).is_in_fft_space(domain_log_half_size));
let poly = CpuCirclePoly::new(coeffs);
let values = poly.evaluate(domain);
let secure_column = SecureColumnByCoords {
columns: [
values.values.clone(),
values.values.clone(),
values.values.clone(),
values.values.clone(),
],
};
let secure_eval = SecureEvaluation::<CpuBackend, BitReversedOrder>::new(
domain,
secure_column.clone(),
);
let (g, lambda) = CpuBackend::decompose(&secure_eval);
// Sanity check.
assert_ne!(lambda, SecureField::zero());
// Assert the new polynomial is in the FFT space.
for i in 0..4 {
let basefield_column = g.columns[i].clone();
let eval = CpuCircleEvaluation::new(domain, basefield_column);
let coeffs = eval.interpolate().coeffs;
assert!(CpuCirclePoly::new(coeffs).is_in_fft_space(domain_log_half_size));
}
}
}
}

View File

@ -0,0 +1,18 @@
use super::CpuBackend;
use crate::core::channel::Channel;
use crate::core::proof_of_work::GrindOps;
impl<C: Channel> GrindOps<C> for CpuBackend {
fn grind(channel: &C, pow_bits: u32) -> u64 {
// TODO(spapini): This is a naive implementation. Optimize it.
let mut nonce = 0;
loop {
let mut channel = channel.clone();
channel.mix_nonce(nonce);
if channel.trailing_zeros() >= pow_bits {
return nonce;
}
nonce += 1;
}
}
}

View File

@ -0,0 +1,448 @@
use std::ops::Index;
use num_traits::{One, Zero};
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{ExtensionOf, Field};
use crate::core::lookups::gkr_prover::{
correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer,
};
use crate::core::lookups::mle::{Mle, MleOps};
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::{Fraction, Reciprocal, UnivariatePoly};
impl GkrOps for CpuBackend {
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField> {
Mle::new(gen_eq_evals(y, v))
}
fn next_layer(layer: &Layer<Self>) -> Layer<Self> {
match layer {
Layer::GrandProduct(layer) => next_grand_product_layer(layer),
Layer::LogUpGeneric {
numerators,
denominators,
} => next_logup_layer(MleExpr::Mle(numerators), denominators),
Layer::LogUpMultiplicities {
numerators,
denominators,
} => next_logup_layer(MleExpr::Mle(numerators), denominators),
Layer::LogUpSingles { denominators } => {
next_logup_layer(MleExpr::Constant(BaseField::one()), denominators)
}
}
}
fn sum_as_poly_in_first_variable(
h: &GkrMultivariatePolyOracle<'_, Self>,
claim: SecureField,
) -> UnivariatePoly<SecureField> {
let n_variables = h.n_variables();
assert!(!n_variables.is_zero());
let n_terms = 1 << (n_variables - 1);
let eq_evals = h.eq_evals.as_ref();
// Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube.
let y = eq_evals.y();
let lambda = h.lambda;
let (mut eval_at_0, mut eval_at_2) = match &h.input_layer {
Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_terms),
Layer::LogUpGeneric {
numerators,
denominators,
} => eval_logup_sum(eq_evals, numerators, denominators, n_terms, lambda),
Layer::LogUpMultiplicities {
numerators,
denominators,
} => eval_logup_sum(eq_evals, numerators, denominators, n_terms, lambda),
Layer::LogUpSingles { denominators } => {
eval_logup_singles_sum(eq_evals, denominators, n_terms, lambda)
}
};
eval_at_0 *= h.eq_fixed_var_correction;
eval_at_2 *= h.eq_fixed_var_correction;
correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables)
}
}
/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`.
///
/// Output of the form: `(eval_at_0, eval_at_2)`.
fn eval_grand_product_sum(
eq_evals: &EqEvals<CpuBackend>,
input_layer: &Mle<CpuBackend, SecureField>,
n_terms: usize,
) -> (SecureField, SecureField) {
let mut eval_at_0 = SecureField::zero();
let mut eval_at_2 = SecureField::zero();
for i in 0..n_terms {
// Input polynomial at points `(r, {0, 1, 2}, bits(i), {0, 1})`.
let inp_at_r0i0 = input_layer[i * 2];
let inp_at_r0i1 = input_layer[i * 2 + 1];
let inp_at_r1i0 = input_layer[(n_terms + i) * 2];
let inp_at_r1i1 = input_layer[(n_terms + i) * 2 + 1];
// Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)`
// => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)`
// TODO(andrew): Consider evaluation at `1/2` to save an addition operation since
// `inp(r, 1/2, x) = 1/2 * (inp(r, 1, x) + inp(r, 0, x))`. `1/2 * ...` can be factored
// outside the loop.
let inp_at_r2i0 = inp_at_r1i0.double() - inp_at_r0i0;
let inp_at_r2i1 = inp_at_r1i1.double() - inp_at_r0i1;
// Product polynomial `prod(x) = inp(x, 0) * inp(x, 1)` at points `(r, {0, 2}, bits(i))`.
let prod_at_r2i = inp_at_r2i0 * inp_at_r2i1;
let prod_at_r0i = inp_at_r0i0 * inp_at_r0i1;
let eq_eval_at_0i = eq_evals[i];
eval_at_0 += eq_eval_at_0i * prod_at_r0i;
eval_at_2 += eq_eval_at_0i * prod_at_r2i;
}
(eval_at_0, eval_at_2)
}
/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_numer(r, t, x, 0) * inp_denom(r, t, x, 1) +
/// inp_numer(r, t, x, 1) * inp_denom(r, t, x, 0) + lambda * inp_denom(r, t, x, 0) * inp_denom(r, t,
/// x, 1))` at `t=0` and `t=2`.
///
/// Output of the form: `(eval_at_0, eval_at_2)`.
fn eval_logup_sum<F: Field>(
eq_evals: &EqEvals<CpuBackend>,
input_numerators: &Mle<CpuBackend, F>,
input_denominators: &Mle<CpuBackend, SecureField>,
n_terms: usize,
lambda: SecureField,
) -> (SecureField, SecureField)
where
SecureField: ExtensionOf<F> + Field,
{
let mut eval_at_0 = SecureField::zero();
let mut eval_at_2 = SecureField::zero();
for i in 0..n_terms {
// Input polynomials at points `(r, {0, 1, 2}, bits(i), {0, 1})`.
let inp_numer_at_r0i0 = input_numerators[i * 2];
let inp_denom_at_r0i0 = input_denominators[i * 2];
let inp_numer_at_r0i1 = input_numerators[i * 2 + 1];
let inp_denom_at_r0i1 = input_denominators[i * 2 + 1];
let inp_numer_at_r1i0 = input_numerators[(n_terms + i) * 2];
let inp_denom_at_r1i0 = input_denominators[(n_terms + i) * 2];
let inp_numer_at_r1i1 = input_numerators[(n_terms + i) * 2 + 1];
let inp_denom_at_r1i1 = input_denominators[(n_terms + i) * 2 + 1];
// Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)`
// => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)`
let inp_numer_at_r2i0 = inp_numer_at_r1i0.double() - inp_numer_at_r0i0;
let inp_denom_at_r2i0 = inp_denom_at_r1i0.double() - inp_denom_at_r0i0;
let inp_numer_at_r2i1 = inp_numer_at_r1i1.double() - inp_numer_at_r0i1;
let inp_denom_at_r2i1 = inp_denom_at_r1i1.double() - inp_denom_at_r0i1;
// Fraction addition polynomials:
// - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)`
// - `denom(x) = inp_denom(x, 1) * inp_denom(x, 0)`
// at points `(r, {0, 2}, bits(i))`.
let Fraction {
numerator: numer_at_r0i,
denominator: denom_at_r0i,
} = Fraction::new(inp_numer_at_r0i0, inp_denom_at_r0i0)
+ Fraction::new(inp_numer_at_r0i1, inp_denom_at_r0i1);
let Fraction {
numerator: numer_at_r2i,
denominator: denom_at_r2i,
} = Fraction::new(inp_numer_at_r2i0, inp_denom_at_r2i0)
+ Fraction::new(inp_numer_at_r2i1, inp_denom_at_r2i1);
let eq_eval_at_0i = eq_evals[i];
eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i);
eval_at_2 += eq_eval_at_0i * (numer_at_r2i + lambda * denom_at_r2i);
}
(eval_at_0, eval_at_2)
}
/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_denom(r, t, x, 1) + inp_denom(r, t, x, 0) +
/// lambda * inp_denom(r, t, x, 0) * inp_denom(r, t, x, 1))` at `t=0` and `t=2`.
///
/// Output of the form: `(eval_at_0, eval_at_2)`.
fn eval_logup_singles_sum(
eq_evals: &EqEvals<CpuBackend>,
input_denominators: &Mle<CpuBackend, SecureField>,
n_terms: usize,
lambda: SecureField,
) -> (SecureField, SecureField) {
let mut eval_at_0 = SecureField::zero();
let mut eval_at_2 = SecureField::zero();
for i in 0..n_terms {
// Input polynomial at points `(r, {0, 1, 2}, bits(i), {0, 1})`.
let inp_denom_at_r0i0 = input_denominators[i * 2];
let inp_denom_at_r0i1 = input_denominators[i * 2 + 1];
let inp_denom_at_r1i0 = input_denominators[(n_terms + i) * 2];
let inp_denom_at_r1i1 = input_denominators[(n_terms + i) * 2 + 1];
// Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)`
// => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)`
let inp_denom_at_r2i0 = inp_denom_at_r1i0.double() - inp_denom_at_r0i0;
let inp_denom_at_r2i1 = inp_denom_at_r1i1.double() - inp_denom_at_r0i1;
// Fraction addition polynomials at points:
// - `numer(x) = inp_denom(x, 1) + inp_denom(x, 0)`
// - `denom(x) = inp_denom(x, 1) * inp_denom(x, 0)`
// at points `(r, {0, 2}, bits(i))`.
let Fraction {
numerator: numer_at_r0i,
denominator: denom_at_r0i,
} = Reciprocal::new(inp_denom_at_r0i0) + Reciprocal::new(inp_denom_at_r0i1);
let Fraction {
numerator: numer_at_r2i,
denominator: denom_at_r2i,
} = Reciprocal::new(inp_denom_at_r2i0) + Reciprocal::new(inp_denom_at_r2i1);
let eq_eval_at_0i = eq_evals[i];
eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i);
eval_at_2 += eq_eval_at_0i * (numer_at_r2i + lambda * denom_at_r2i);
}
(eval_at_0, eval_at_2)
}
/// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`.
///
/// Evaluations are returned in bit-reversed order.
pub fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Vec<SecureField> {
let mut evals = Vec::with_capacity(1 << y.len());
evals.push(v);
for &y_i in y.iter().rev() {
for j in 0..evals.len() {
// `lhs[j] = eq(0, y_i) * c[i]`
// `rhs[j] = eq(1, y_i) * c[i]`
let tmp = evals[j] * y_i;
evals.push(tmp);
evals[j] -= tmp;
}
}
evals
}
fn next_grand_product_layer(layer: &Mle<CpuBackend, SecureField>) -> Layer<CpuBackend> {
let res = layer.array_chunks().map(|&[a, b]| a * b).collect();
Layer::GrandProduct(Mle::new(res))
}
fn next_logup_layer<F>(
numerators: MleExpr<'_, F>,
denominators: &Mle<CpuBackend, SecureField>,
) -> Layer<CpuBackend>
where
F: Field,
SecureField: ExtensionOf<F>,
CpuBackend: MleOps<F>,
{
let half_n = 1 << (denominators.n_variables() - 1);
let mut next_numerators = Vec::with_capacity(half_n);
let mut next_denominators = Vec::with_capacity(half_n);
for i in 0..half_n {
let a = Fraction::new(numerators[i * 2], denominators[i * 2]);
let b = Fraction::new(numerators[i * 2 + 1], denominators[i * 2 + 1]);
let res = a + b;
next_numerators.push(res.numerator);
next_denominators.push(res.denominator);
}
Layer::LogUpGeneric {
numerators: Mle::new(next_numerators),
denominators: Mle::new(next_denominators),
}
}
enum MleExpr<'a, F: Field> {
Constant(F),
Mle(&'a Mle<CpuBackend, F>),
}
impl<'a, F: Field> Index<usize> for MleExpr<'a, F> {
type Output = F;
fn index(&self, index: usize) -> &F {
match self {
Self::Constant(v) => v,
Self::Mle(mle) => &mle[index],
}
}
}
#[cfg(test)]
mod tests {
use std::iter::zip;
use num_traits::{One, Zero};
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use crate::core::backend::CpuBackend;
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer};
use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::utils::{eq, Fraction};
use crate::core::test_utils::test_channel;
#[test]
fn gen_eq_evals() {
let zero = SecureField::zero();
let one = SecureField::one();
let two = BaseField::from(2).into();
let y = [7, 3].map(|v| BaseField::from(v).into());
let eq_evals = CpuBackend::gen_eq_evals(&y, two);
assert_eq!(
*eq_evals,
[
eq(&[zero, zero], &y) * two,
eq(&[zero, one], &y) * two,
eq(&[one, zero], &y) * two,
eq(&[one, one], &y) * two,
]
);
}
#[test]
fn grand_product_works() -> Result<(), GkrError> {
const N: usize = 1 << 5;
let values = test_channel().draw_felts(N);
let product = values.iter().product::<SecureField>();
let col = Mle::<CpuBackend, SecureField>::new(values);
let input_layer = Layer::GrandProduct(col.clone());
let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);
let GkrArtifact {
ood_point: r,
claims_to_verify_by_instance,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;
assert_eq!(proof.output_claims_by_instance, [vec![product]]);
assert_eq!(claims_to_verify_by_instance, [vec![col.eval_at_point(&r)]]);
Ok(())
}
#[test]
fn logup_with_generic_trace_works() -> Result<(), GkrError> {
const N: usize = 1 << 5;
let mut rng = SmallRng::seed_from_u64(0);
let numerator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let sum = zip(&numerator_values, &denominator_values)
.map(|(&n, &d)| Fraction::new(n, d))
.sum::<Fraction<SecureField, SecureField>>();
let numerators = Mle::<CpuBackend, SecureField>::new(numerator_values);
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
let top_layer = Layer::LogUpGeneric {
numerators: numerators.clone(),
denominators: denominators.clone(),
};
let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer]);
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;
assert_eq!(claims_to_verify_by_instance.len(), 1);
assert_eq!(proof.output_claims_by_instance.len(), 1);
assert_eq!(
claims_to_verify_by_instance[0],
[
numerators.eval_at_point(&ood_point),
denominators.eval_at_point(&ood_point)
]
);
assert_eq!(
proof.output_claims_by_instance[0],
[sum.numerator, sum.denominator]
);
Ok(())
}
#[test]
fn logup_with_singles_trace_works() -> Result<(), GkrError> {
const N: usize = 1 << 5;
let mut rng = SmallRng::seed_from_u64(0);
let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let sum = denominator_values
.iter()
.map(|&d| Fraction::new(SecureField::one(), d))
.sum::<Fraction<SecureField, SecureField>>();
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
let top_layer = Layer::LogUpSingles {
denominators: denominators.clone(),
};
let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer]);
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;
assert_eq!(claims_to_verify_by_instance.len(), 1);
assert_eq!(proof.output_claims_by_instance.len(), 1);
assert_eq!(
claims_to_verify_by_instance[0],
[SecureField::one(), denominators.eval_at_point(&ood_point)]
);
assert_eq!(
proof.output_claims_by_instance[0],
[sum.numerator, sum.denominator]
);
Ok(())
}
#[test]
fn logup_with_multiplicities_trace_works() -> Result<(), GkrError> {
const N: usize = 1 << 5;
let mut rng = SmallRng::seed_from_u64(0);
let numerator_values = (0..N).map(|_| rng.gen()).collect::<Vec<BaseField>>();
let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let sum = zip(&numerator_values, &denominator_values)
.map(|(&n, &d)| Fraction::new(n.into(), d))
.sum::<Fraction<SecureField, SecureField>>();
let numerators = Mle::<CpuBackend, BaseField>::new(numerator_values);
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
let top_layer = Layer::LogUpMultiplicities {
numerators: numerators.clone(),
denominators: denominators.clone(),
};
let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer]);
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;
assert_eq!(claims_to_verify_by_instance.len(), 1);
assert_eq!(proof.output_claims_by_instance.len(), 1);
assert_eq!(
claims_to_verify_by_instance[0],
[
numerators.eval_at_point(&ood_point),
denominators.eval_at_point(&ood_point)
]
);
assert_eq!(
proof.output_claims_by_instance[0],
[sum.numerator, sum.denominator]
);
Ok(())
}
}

View File

@ -0,0 +1,66 @@
use std::iter::zip;
use num_traits::{One, Zero};
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::mle::{Mle, MleOps};
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::{fold_mle_evals, UnivariatePoly};
impl MleOps<BaseField> for CpuBackend {
fn fix_first_variable(
mle: Mle<Self, BaseField>,
assignment: SecureField,
) -> Mle<Self, SecureField> {
let midpoint = mle.len() / 2;
let (lhs_evals, rhs_evals) = mle.split_at(midpoint);
let res = zip(lhs_evals, rhs_evals)
.map(|(&lhs_eval, &rhs_eval)| fold_mle_evals(assignment, lhs_eval, rhs_eval))
.collect();
Mle::new(res)
}
}
impl MleOps<SecureField> for CpuBackend {
fn fix_first_variable(
mle: Mle<Self, SecureField>,
assignment: SecureField,
) -> Mle<Self, SecureField> {
let midpoint = mle.len() / 2;
let mut evals = mle.into_evals();
for i in 0..midpoint {
let lhs_eval = evals[i];
let rhs_eval = evals[i + midpoint];
evals[i] = fold_mle_evals(assignment, lhs_eval, rhs_eval);
}
evals.truncate(midpoint);
Mle::new(evals)
}
}
impl MultivariatePolyOracle for Mle<CpuBackend, SecureField> {
fn n_variables(&self) -> usize {
self.n_variables()
}
fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly<SecureField> {
let x0 = SecureField::zero();
let x1 = SecureField::one();
let y0 = self[0..self.len() / 2].iter().sum();
let y1 = claim - y0;
UnivariatePoly::interpolate_lagrange(&[x0, x1], &[y0, y1])
}
fn fix_first_variable(self, challenge: SecureField) -> Self {
self.fix_first_variable(challenge)
}
}

View File

@ -0,0 +1,2 @@
pub mod gkr;
mod mle;

View File

@ -0,0 +1,105 @@
mod accumulation;
mod blake2s;
mod circle;
mod fri;
mod grind;
pub mod lookups;
#[cfg(not(target_arch = "wasm32"))]
mod poseidon252;
pub mod quotients;
#[cfg(not(target_arch = "wasm32"))]
mod poseidon_bls;
use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use super::{Backend, BackendForChannel, Column, ColumnOps, FieldOps};
use crate::core::fields::Field;
use crate::core::lookups::mle::Mle;
use crate::core::poly::circle::{CircleEvaluation, CirclePoly};
use crate::core::utils::bit_reverse;
use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel;
#[cfg(not(target_arch = "wasm32"))]
use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel;
#[cfg(not(target_arch = "wasm32"))]
use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleChannel;
#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
pub struct CpuBackend;
impl Backend for CpuBackend {}
impl BackendForChannel<Blake2sMerkleChannel> for CpuBackend {}
#[cfg(not(target_arch = "wasm32"))]
impl BackendForChannel<Poseidon252MerkleChannel> for CpuBackend {}
#[cfg(not(target_arch = "wasm32"))]
impl BackendForChannel<PoseidonBLSMerkleChannel> for CpuBackend {}
impl<T: Debug + Clone + Default> ColumnOps<T> for CpuBackend {
type Column = Vec<T>;
fn bit_reverse_column(column: &mut Self::Column) {
bit_reverse(column)
}
}
impl<F: Field> FieldOps<F> for CpuBackend {
/// Batch inversion using the Montgomery's trick.
// TODO(Ohad): Benchmark this function.
fn batch_inverse(column: &Self::Column, dst: &mut Self::Column) {
F::batch_inverse(column, &mut dst[..]);
}
}
impl<T: Debug + Clone + Default> Column<T> for Vec<T> {
fn zeros(len: usize) -> Self {
vec![T::default(); len]
}
#[allow(clippy::uninit_vec)]
unsafe fn uninitialized(length: usize) -> Self {
let mut data = Vec::with_capacity(length);
data.set_len(length);
data
}
fn to_cpu(&self) -> Vec<T> {
self.clone()
}
fn len(&self) -> usize {
self.len()
}
fn at(&self, index: usize) -> T {
self[index].clone()
}
fn set(&mut self, index: usize, value: T) {
self[index] = value;
}
}
pub type CpuCirclePoly = CirclePoly<CpuBackend>;
pub type CpuCircleEvaluation<F, EvalOrder> = CircleEvaluation<CpuBackend, F, EvalOrder>;
pub type CpuMle<F> = Mle<CpuBackend, F>;
#[cfg(test)]
mod tests {
use itertools::Itertools;
use rand::prelude::*;
use rand::rngs::SmallRng;
use crate::core::backend::{Column, CpuBackend, FieldOps};
use crate::core::fields::qm31::QM31;
use crate::core::fields::FieldExpOps;
#[test]
fn batch_inverse_test() {
let mut rng = SmallRng::seed_from_u64(0);
let column = rng.gen::<[QM31; 16]>().to_vec();
let expected = column.iter().map(|e| e.inverse()).collect_vec();
let mut dst = Column::zeros(column.len());
CpuBackend::batch_inverse(&column, &mut dst);
assert_eq!(expected, dst);
}
}

View File

@ -0,0 +1,24 @@
use itertools::Itertools;
use starknet_ff::FieldElement as FieldElement252;
use super::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};
use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleHasher;
impl MerkleOps<Poseidon252MerkleHasher> for CpuBackend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<FieldElement252>>,
columns: &[&Vec<BaseField>],
) -> Vec<FieldElement252> {
(0..(1 << log_size))
.map(|i| {
Poseidon252MerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column[i]).collect_vec(),
)
})
.collect()
}
}

View File

@ -0,0 +1,24 @@
use itertools::Itertools;
use ark_bls12_381::Fr as BlsFr;
use super::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};
use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleHasher;
impl MerkleOps<PoseidonBLSMerkleHasher> for CpuBackend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<BlsFr>>,
columns: &[&Vec<BaseField>],
) -> Vec<BlsFr> {
(0..(1 << log_size))
.map(|i| {
PoseidonBLSMerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column[i]).collect_vec(),
)
})
.collect()
}
}

View File

@ -0,0 +1,210 @@
use itertools::{izip, zip_eq};
use num_traits::{One, Zero};
use super::CpuBackend;
use crate::core::circle::CirclePoint;
use crate::core::constraints::complex_conjugate_line_coeffs;
use crate::core::fields::cm31::CM31;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fields::FieldExpOps;
use crate::core::pcs::quotients::{ColumnSampleBatch, PointSample, QuotientOps};
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse, bit_reverse_index};
impl QuotientOps for CpuBackend {
fn accumulate_quotients(
domain: CircleDomain,
columns: &[&CircleEvaluation<Self, BaseField, BitReversedOrder>],
random_coeff: SecureField,
sample_batches: &[ColumnSampleBatch],
_log_blowup_factor: u32,
) -> SecureEvaluation<Self, BitReversedOrder> {
let mut values = unsafe { SecureColumnByCoords::uninitialized(domain.size()) };
let quotient_constants = quotient_constants(sample_batches, random_coeff, domain);
for row in 0..domain.size() {
let domain_point = domain.at(bit_reverse_index(row, domain.log_size()));
let row_value = accumulate_row_quotients(
sample_batches,
columns,
&quotient_constants,
row,
domain_point,
);
values.set(row, row_value);
}
SecureEvaluation::new(domain, values)
}
}
pub fn accumulate_row_quotients(
sample_batches: &[ColumnSampleBatch],
columns: &[&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>],
quotient_constants: &QuotientConstants,
row: usize,
domain_point: CirclePoint<BaseField>,
) -> SecureField {
let mut row_accumulator = SecureField::zero();
for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!(
sample_batches,
&quotient_constants.line_coeffs,
&quotient_constants.batch_random_coeffs,
&quotient_constants.denominator_inverses
) {
let mut numerator = SecureField::zero();
for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs)
{
let column = &columns[*column_index];
let value = column[row] * *c;
// The numerator is a line equation passing through
// (sample_point.y, sample_value), (conj(sample_point), conj(sample_value))
// evaluated at (domain_point.y, value).
// When substituting a polynomial in this line equation, we get a polynomial with a root
// at sample_point and conj(sample_point) if the original polynomial had the values
// sample_value and conj(sample_value) at these points.
let linear_term = *a * domain_point.y + *b;
numerator += value - linear_term;
}
row_accumulator =
row_accumulator * *batch_coeff + numerator.mul_cm31(denominator_inverses[row]);
}
row_accumulator
}
/// Precompute the complex conjugate line coefficients for each column in each sample batch.
/// Specifically, for the i-th (in a sample batch) column's numerator term
/// `alpha^i * (c * F(p) - (a * p.y + b))`, we precompute and return the constants:
/// (`alpha^i * a`, `alpha^i * b`, `alpha^i * c`).
pub fn column_line_coeffs(
sample_batches: &[ColumnSampleBatch],
random_coeff: SecureField,
) -> Vec<Vec<(SecureField, SecureField, SecureField)>> {
sample_batches
.iter()
.map(|sample_batch| {
let mut alpha = SecureField::one();
sample_batch
.columns_and_values
.iter()
.map(|(_, sampled_value)| {
alpha *= random_coeff;
let sample = PointSample {
point: sample_batch.point,
value: *sampled_value,
};
complex_conjugate_line_coeffs(&sample, alpha)
})
.collect()
})
.collect()
}
/// Precompute the random coefficients used to linearly combine the batched quotients.
/// Specifically, for each sample batch we compute random_coeff^(number of columns in the batch),
/// which is used to linearly combine the batch with the next one.
pub fn batch_random_coeffs(
sample_batches: &[ColumnSampleBatch],
random_coeff: SecureField,
) -> Vec<SecureField> {
sample_batches
.iter()
.map(|sb| random_coeff.pow(sb.columns_and_values.len() as u128))
.collect()
}
fn denominator_inverses(
sample_batches: &[ColumnSampleBatch],
domain: CircleDomain,
) -> Vec<Vec<CM31>> {
let mut flat_denominators = Vec::with_capacity(sample_batches.len() * domain.size());
// We want a P to be on a line that passes through a point Pr + uPi in QM31^2, and its conjugate
// Pr - uPi. Thus, Pr - P is parallel to Pi. Or, (Pr - P).x * Pi.y - (Pr - P).y * Pi.x = 0.
for sample_batch in sample_batches {
// Extract Pr, Pi.
let prx = sample_batch.point.x.0;
let pry = sample_batch.point.y.0;
let pix = sample_batch.point.x.1;
let piy = sample_batch.point.y.1;
for row in 0..domain.size() {
let domain_point = domain.at(row);
flat_denominators.push((prx - domain_point.x) * piy - (pry - domain_point.y) * pix);
}
}
let mut flat_denominator_inverses = vec![CM31::zero(); flat_denominators.len()];
CM31::batch_inverse(&flat_denominators, &mut flat_denominator_inverses);
flat_denominator_inverses
.chunks_mut(domain.size())
.map(|denominator_inverses| {
bit_reverse(denominator_inverses);
denominator_inverses.to_vec()
})
.collect()
}
pub fn quotient_constants(
sample_batches: &[ColumnSampleBatch],
random_coeff: SecureField,
domain: CircleDomain,
) -> QuotientConstants {
let line_coeffs = column_line_coeffs(sample_batches, random_coeff);
let batch_random_coeffs = batch_random_coeffs(sample_batches, random_coeff);
let denominator_inverses = denominator_inverses(sample_batches, domain);
QuotientConstants {
line_coeffs,
batch_random_coeffs,
denominator_inverses,
}
}
/// Holds the precomputed constant values used in each quotient evaluation.
pub struct QuotientConstants {
/// The line coefficients for each quotient numerator term. For more details see
/// [self::column_line_coeffs].
pub line_coeffs: Vec<Vec<(SecureField, SecureField, SecureField)>>,
/// The random coefficients used to linearly combine the batched quotients For more details see
/// [self::batch_random_coeffs].
pub batch_random_coeffs: Vec<SecureField>,
/// The inverses of the denominators of the quotients.
pub denominator_inverses: Vec<Vec<CM31>>,
}
#[cfg(test)]
mod tests {
use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly};
use crate::core::backend::CpuBackend;
use crate::core::circle::SECURE_FIELD_CIRCLE_GEN;
use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps};
use crate::core::poly::circle::CanonicCoset;
use crate::{m31, qm31};
#[test]
fn test_quotients_are_low_degree() {
const LOG_SIZE: u32 = 7;
const LOG_BLOWUP_FACTOR: u32 = 1;
let polynomial = CpuCirclePoly::new((0..1 << LOG_SIZE).map(|i| m31!(i)).collect());
let eval_domain = CanonicCoset::new(LOG_SIZE + 1).circle_domain();
let eval = polynomial.evaluate(eval_domain);
let point = SECURE_FIELD_CIRCLE_GEN;
let value = polynomial.eval_at_point(point);
let coeff = qm31!(1, 2, 3, 4);
let quot_eval = CpuBackend::accumulate_quotients(
eval_domain,
&[&eval],
coeff,
&[ColumnSampleBatch {
point,
columns_and_values: vec![(0, value)],
}],
LOG_BLOWUP_FACTOR,
);
let quot_poly_base_field =
CpuCircleEvaluation::new(eval_domain, quot_eval.columns[0].clone()).interpolate();
assert!(quot_poly_base_field.is_in_fri_space(LOG_SIZE));
}
}

View File

@ -0,0 +1,66 @@
use std::fmt::Debug;
pub use cpu::CpuBackend;
use super::air::accumulation::AccumulationOps;
use super::channel::MerkleChannel;
use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::fields::FieldOps;
use super::fri::FriOps;
use super::lookups::gkr_prover::GkrOps;
use super::pcs::quotients::QuotientOps;
use super::poly::circle::PolyOps;
use super::proof_of_work::GrindOps;
use super::vcs::ops::MerkleOps;
pub mod cpu;
pub mod simd;
pub trait Backend:
Copy
+ Clone
+ Debug
+ FieldOps<BaseField>
+ FieldOps<SecureField>
+ PolyOps
+ QuotientOps
+ FriOps
+ AccumulationOps
+ GkrOps
{
}
pub trait BackendForChannel<MC: MerkleChannel>:
Backend + MerkleOps<MC::H> + GrindOps<MC::C>
{
}
pub trait ColumnOps<T> {
type Column: Column<T>;
fn bit_reverse_column(column: &mut Self::Column);
}
pub type Col<B, T> = <B as ColumnOps<T>>::Column;
// TODO(spapini): Consider removing the generic parameter and only support BaseField.
pub trait Column<T>: Clone + Debug + FromIterator<T> {
/// Creates a new column of zeros with the given length.
fn zeros(len: usize) -> Self;
/// Creates a new column of uninitialized values with the given length.
/// # Safety
/// The caller must ensure that the column is populated before being used.
unsafe fn uninitialized(len: usize) -> Self;
/// Returns a cpu vector of the column.
fn to_cpu(&self) -> Vec<T>;
/// Returns the length of the column.
fn len(&self) -> usize;
/// Returns true if the column is empty.
fn is_empty(&self) -> bool {
self.len() == 0
}
/// Retrieves the element at the given index.
fn at(&self, index: usize) -> T;
/// Sets the element at the given index.
fn set(&mut self, index: usize, value: T);
}

View File

@ -0,0 +1,12 @@
use super::SimdBackend;
use crate::core::air::accumulation::AccumulationOps;
use crate::core::fields::secure_column::SecureColumnByCoords;
impl AccumulationOps for SimdBackend {
fn accumulate(column: &mut SecureColumnByCoords<Self>, other: &SecureColumnByCoords<Self>) {
for i in 0..column.packed_len() {
let res_coeff = unsafe { column.packed_at(i) + other.packed_at(i) };
unsafe { column.set_packed(i, res_coeff) };
}
}
}

View File

@ -0,0 +1,203 @@
use std::array;
use super::column::{BaseColumn, SecureColumn};
use super::m31::PackedBaseField;
use super::SimdBackend;
use crate::core::backend::ColumnOps;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::utils::{bit_reverse as cpu_bit_reverse, bit_reverse_index};
const VEC_BITS: u32 = 4;
const W_BITS: u32 = 3;
pub const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS;
impl ColumnOps<BaseField> for SimdBackend {
type Column = BaseColumn;
fn bit_reverse_column(column: &mut Self::Column) {
// Fallback to cpu bit_reverse.
if column.data.len().ilog2() < MIN_LOG_SIZE {
cpu_bit_reverse(column.as_mut_slice());
return;
}
bit_reverse_m31(&mut column.data);
}
}
impl ColumnOps<SecureField> for SimdBackend {
type Column = SecureColumn;
fn bit_reverse_column(_column: &mut SecureColumn) {
todo!()
}
}
/// Bit reverses M31 values.
///
/// Given an array `A[0..2^n)`, computes `B[i] = A[bit_reverse(i)]`.
pub fn bit_reverse_m31(data: &mut [PackedBaseField]) {
assert!(data.len().is_power_of_two());
assert!(data.len().ilog2() >= MIN_LOG_SIZE);
// Indices in the array are of the form v_h w_h a w_l v_l, with
// |v_h| = |v_l| = VEC_BITS, |w_h| = |w_l| = W_BITS, |a| = n - 2*W_BITS - VEC_BITS.
// The loops go over a, w_l, w_h, and then swaps the 16 by 16 values at:
// * w_h a w_l * <-> * rev(w_h a w_l) *.
// These are 1 or 2 chunks of 2^W_BITS contiguous `u32x16` vectors.
let log_size = data.len().ilog2();
let a_bits = log_size - 2 * W_BITS - VEC_BITS;
// TODO(spapini): when doing multithreading, do it over a.
for a in 0u32..1 << a_bits {
for w_l in 0u32..1 << W_BITS {
let w_l_rev = w_l.reverse_bits() >> (u32::BITS - W_BITS);
for w_h in 0..w_l_rev + 1 {
let idx = ((((w_h << a_bits) | a) << W_BITS) | w_l) as usize;
let idx_rev = bit_reverse_index(idx, log_size - VEC_BITS);
// In order to not swap twice, only swap if idx <= idx_rev.
if idx > idx_rev {
continue;
}
// Read first chunk.
// TODO(spapini): Think about optimizing a_bits.
let chunk0 = array::from_fn(|i| unsafe {
*data.get_unchecked(idx + (i << (2 * W_BITS + a_bits)))
});
let values0 = bit_reverse16(chunk0);
if idx == idx_rev {
// Palindrome index. Write into the same chunk.
#[allow(clippy::needless_range_loop)]
for i in 0..16 {
unsafe {
*data.get_unchecked_mut(idx + (i << (2 * W_BITS + a_bits))) =
values0[i];
}
}
continue;
}
// Read bit reversed chunk.
let chunk1 = array::from_fn(|i| unsafe {
*data.get_unchecked(idx_rev + (i << (2 * W_BITS + a_bits)))
});
let values1 = bit_reverse16(chunk1);
for i in 0..16 {
unsafe {
*data.get_unchecked_mut(idx + (i << (2 * W_BITS + a_bits))) = values1[i];
*data.get_unchecked_mut(idx_rev + (i << (2 * W_BITS + a_bits))) =
values0[i];
}
}
}
}
}
}
/// Bit reverses 256 M31 values, packed in 16 words of 16 elements each.
fn bit_reverse16(mut data: [PackedBaseField; 16]) -> [PackedBaseField; 16] {
// Denote the index of each element in the 16 packed M31 words as abcd:0123,
// where abcd is the index of the packed word and 0123 is the index of the element in the word.
// Bit reversal is achieved by applying the following permutation to the index for 4 times:
// abcd:0123 => 0abc:123d
// This is how it looks like at each iteration.
// abcd:0123
// 0abc:123d
// 10ab:23dc
// 210a:3dcb
// 3210:dcba
for _ in 0..4 {
// Apply the abcd:0123 => 0abc:123d permutation.
// `interleave` allows us to interleave the first half of 2 words.
// For example, the second call interleaves 0010:0xyz (low half of register 2) with
// 0011:0xyz (low half of register 3), and stores the result in register 1 (0001).
// This results in
// 0001:xyz0 (even indices of register 1) <= 0010:0xyz (low half of register2), and
// 0001:xyz1 (odd indices of register 1) <= 0011:0xyz (low half of register 3)
// or 0001:xyzw <= 001w:0xyz.
let (d0, d8) = data[0].interleave(data[1]);
let (d1, d9) = data[2].interleave(data[3]);
let (d2, d10) = data[4].interleave(data[5]);
let (d3, d11) = data[6].interleave(data[7]);
let (d4, d12) = data[8].interleave(data[9]);
let (d5, d13) = data[10].interleave(data[11]);
let (d6, d14) = data[12].interleave(data[13]);
let (d7, d15) = data[14].interleave(data[15]);
data = [
d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15,
];
}
data
}
#[cfg(test)]
mod tests {
use itertools::Itertools;
use super::{bit_reverse16, bit_reverse_m31, MIN_LOG_SIZE};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedM31, N_LANES};
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, ColumnOps};
use crate::core::fields::m31::BaseField;
use crate::core::utils::bit_reverse as cpu_bit_reverse;
#[test]
fn test_bit_reverse16() {
let values: BaseColumn = (0..N_LANES * 16).map(BaseField::from).collect();
let mut expected = values.to_cpu();
cpu_bit_reverse(&mut expected);
let res = bit_reverse16(values.data.try_into().unwrap());
assert_eq!(res.map(PackedM31::to_array).flatten(), expected);
}
#[test]
fn bit_reverse_m31_works() {
const SIZE: usize = 1 << 15;
let data: Vec<_> = (0..SIZE).map(BaseField::from).collect();
let mut expected = data.clone();
cpu_bit_reverse(&mut expected);
let mut res: BaseColumn = data.into_iter().collect();
bit_reverse_m31(&mut res.data[..]);
assert_eq!(res.to_cpu(), expected);
}
#[test]
fn bit_reverse_small_column_works() {
const LOG_SIZE: u32 = MIN_LOG_SIZE - 1;
let column = (0..1 << LOG_SIZE).map(BaseField::from).collect_vec();
let mut expected = column.clone();
cpu_bit_reverse(&mut expected);
let mut res = column.iter().copied().collect::<BaseColumn>();
<SimdBackend as ColumnOps<BaseField>>::bit_reverse_column(&mut res);
assert_eq!(res.to_cpu(), expected);
}
#[test]
fn bit_reverse_large_column_works() {
const LOG_SIZE: u32 = MIN_LOG_SIZE;
let column = (0..1 << LOG_SIZE).map(BaseField::from).collect_vec();
let mut expected = column.clone();
cpu_bit_reverse(&mut expected);
let mut res = column.iter().copied().collect::<BaseColumn>();
<SimdBackend as ColumnOps<BaseField>>::bit_reverse_column(&mut res);
assert_eq!(res.to_cpu(), expected);
}
}

View File

@ -0,0 +1,412 @@
//! A SIMD implementation of the BLAKE2s compression function.
//! Based on <https://github.com/oconnor663/blake2_simd/blob/master/blake2s/src/avx2.rs>.
use std::array;
use std::iter::repeat;
use std::mem::transmute;
use std::simd::u32x16;
use bytemuck::cast_slice;
use itertools::Itertools;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use super::m31::{LOG_N_LANES, N_LANES};
use super::SimdBackend;
use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::fields::m31::BaseField;
use crate::core::vcs::blake2_hash::Blake2sHash;
use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};
const IV: [u32; 8] = [
0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
];
pub const SIGMA: [[u8; 16]; 10] = [
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3],
[11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4],
[7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8],
[9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13],
[2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9],
[12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11],
[13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10],
[6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5],
[10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0],
];
impl ColumnOps<Blake2sHash> for SimdBackend {
type Column = Vec<Blake2sHash>;
fn bit_reverse_column(_column: &mut Self::Column) {
unimplemented!()
}
}
impl MerkleOps<Blake2sMerkleHasher> for SimdBackend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<Blake2sHash>>,
columns: &[&Col<Self, BaseField>],
) -> Vec<Blake2sHash> {
if log_size < LOG_N_LANES {
#[cfg(not(feature = "parallel"))]
let iter = 0..1 << log_size;
#[cfg(feature = "parallel")]
let iter = (0..1 << log_size).into_par_iter();
return iter
.map(|i| {
Blake2sMerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column.at(i)).collect_vec(),
)
})
.collect();
}
if let Some(prev_layer) = prev_layer {
assert_eq!(prev_layer.len(), 1 << (log_size + 1));
}
let zeros = u32x16::splat(0);
// Commit to columns.
let mut res = vec![Blake2sHash::default(); 1 << log_size];
#[cfg(not(feature = "parallel"))]
let iter = res.chunks_mut(1 << LOG_N_LANES);
#[cfg(feature = "parallel")]
let iter = res.par_chunks_mut(1 << LOG_N_LANES);
iter.enumerate().for_each(|(i, chunk)| {
let mut state: [u32x16; 8] = unsafe { std::mem::zeroed() };
// Hash prev_layer, if exists.
if let Some(prev_layer) = prev_layer {
let prev_chunk_u32s = cast_slice::<_, u32>(&prev_layer[(i << 5)..((i + 1) << 5)]);
// Note: prev_layer might be unaligned.
let msgs: [u32x16; 16] = array::from_fn(|j| {
u32x16::from_array(std::array::from_fn(|k| prev_chunk_u32s[16 * j + k]))
});
state = compress16(state, transpose_msgs(msgs), zeros, zeros, zeros, zeros);
}
// Hash columns in chunks of 16.
let mut col_chunk_iter = columns.array_chunks();
for col_chunk in &mut col_chunk_iter {
let msgs = col_chunk.map(|column| column.data[i].into_simd());
state = compress16(state, msgs, zeros, zeros, zeros, zeros);
}
// Hash remaining columns.
let remainder = col_chunk_iter.remainder();
if !remainder.is_empty() {
let msgs = remainder
.iter()
.map(|column| column.data[i].into_simd())
.chain(repeat(zeros))
.take(N_LANES)
.collect_vec()
.try_into()
.unwrap();
state = compress16(state, msgs, zeros, zeros, zeros, zeros);
}
let state: [Blake2sHash; 16] = unsafe { transmute(untranspose_states(state)) };
chunk.copy_from_slice(&state);
});
res
}
}
/// Applies [`u32::rotate_right(N)`] to each element of the vector
///
/// [`u32::rotate_right(N)`]: u32::rotate_right
#[inline(always)]
fn rotate<const N: u32>(x: u32x16) -> u32x16 {
(x >> N) | (x << (u32::BITS - N))
}
// `inline(always)` can cause code parsing errors for wasm: "locals exceed maximum".
#[cfg_attr(not(target_arch = "wasm32"), inline(always))]
pub fn round(v: &mut [u32x16; 16], m: [u32x16; 16], r: usize) {
v[0] += m[SIGMA[r][0] as usize];
v[1] += m[SIGMA[r][2] as usize];
v[2] += m[SIGMA[r][4] as usize];
v[3] += m[SIGMA[r][6] as usize];
v[0] += v[4];
v[1] += v[5];
v[2] += v[6];
v[3] += v[7];
v[12] ^= v[0];
v[13] ^= v[1];
v[14] ^= v[2];
v[15] ^= v[3];
v[12] = rotate::<16>(v[12]);
v[13] = rotate::<16>(v[13]);
v[14] = rotate::<16>(v[14]);
v[15] = rotate::<16>(v[15]);
v[8] += v[12];
v[9] += v[13];
v[10] += v[14];
v[11] += v[15];
v[4] ^= v[8];
v[5] ^= v[9];
v[6] ^= v[10];
v[7] ^= v[11];
v[4] = rotate::<12>(v[4]);
v[5] = rotate::<12>(v[5]);
v[6] = rotate::<12>(v[6]);
v[7] = rotate::<12>(v[7]);
v[0] += m[SIGMA[r][1] as usize];
v[1] += m[SIGMA[r][3] as usize];
v[2] += m[SIGMA[r][5] as usize];
v[3] += m[SIGMA[r][7] as usize];
v[0] += v[4];
v[1] += v[5];
v[2] += v[6];
v[3] += v[7];
v[12] ^= v[0];
v[13] ^= v[1];
v[14] ^= v[2];
v[15] ^= v[3];
v[12] = rotate::<8>(v[12]);
v[13] = rotate::<8>(v[13]);
v[14] = rotate::<8>(v[14]);
v[15] = rotate::<8>(v[15]);
v[8] += v[12];
v[9] += v[13];
v[10] += v[14];
v[11] += v[15];
v[4] ^= v[8];
v[5] ^= v[9];
v[6] ^= v[10];
v[7] ^= v[11];
v[4] = rotate::<7>(v[4]);
v[5] = rotate::<7>(v[5]);
v[6] = rotate::<7>(v[6]);
v[7] = rotate::<7>(v[7]);
v[0] += m[SIGMA[r][8] as usize];
v[1] += m[SIGMA[r][10] as usize];
v[2] += m[SIGMA[r][12] as usize];
v[3] += m[SIGMA[r][14] as usize];
v[0] += v[5];
v[1] += v[6];
v[2] += v[7];
v[3] += v[4];
v[15] ^= v[0];
v[12] ^= v[1];
v[13] ^= v[2];
v[14] ^= v[3];
v[15] = rotate::<16>(v[15]);
v[12] = rotate::<16>(v[12]);
v[13] = rotate::<16>(v[13]);
v[14] = rotate::<16>(v[14]);
v[10] += v[15];
v[11] += v[12];
v[8] += v[13];
v[9] += v[14];
v[5] ^= v[10];
v[6] ^= v[11];
v[7] ^= v[8];
v[4] ^= v[9];
v[5] = rotate::<12>(v[5]);
v[6] = rotate::<12>(v[6]);
v[7] = rotate::<12>(v[7]);
v[4] = rotate::<12>(v[4]);
v[0] += m[SIGMA[r][9] as usize];
v[1] += m[SIGMA[r][11] as usize];
v[2] += m[SIGMA[r][13] as usize];
v[3] += m[SIGMA[r][15] as usize];
v[0] += v[5];
v[1] += v[6];
v[2] += v[7];
v[3] += v[4];
v[15] ^= v[0];
v[12] ^= v[1];
v[13] ^= v[2];
v[14] ^= v[3];
v[15] = rotate::<8>(v[15]);
v[12] = rotate::<8>(v[12]);
v[13] = rotate::<8>(v[13]);
v[14] = rotate::<8>(v[14]);
v[10] += v[15];
v[11] += v[12];
v[8] += v[13];
v[9] += v[14];
v[5] ^= v[10];
v[6] ^= v[11];
v[7] ^= v[8];
v[4] ^= v[9];
v[5] = rotate::<7>(v[5]);
v[6] = rotate::<7>(v[6]);
v[7] = rotate::<7>(v[7]);
v[4] = rotate::<7>(v[4]);
}
/// Transposes input chunks (16 chunks of 16 `u32`s each), to get 16 `u32x16`, each
/// representing 16 packed instances of a message word.
fn transpose_msgs(mut data: [u32x16; 16]) -> [u32x16; 16] {
// Index abcd:xyzw, refers to a specific word in data as follows:
// abcd - chunk index (in base 2)
// xyzw - word offset (in base 2)
// Transpose by applying 4 times the index permutation:
// abcd:xyzw => wabc:dxyz
// In other words, rotate the index to the right by 1.
for _ in 0..4 {
let (d0, d8) = data[0].deinterleave(data[1]);
let (d1, d9) = data[2].deinterleave(data[3]);
let (d2, d10) = data[4].deinterleave(data[5]);
let (d3, d11) = data[6].deinterleave(data[7]);
let (d4, d12) = data[8].deinterleave(data[9]);
let (d5, d13) = data[10].deinterleave(data[11]);
let (d6, d14) = data[12].deinterleave(data[13]);
let (d7, d15) = data[14].deinterleave(data[15]);
data = [
d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15,
];
}
data
}
fn untranspose_states(mut states: [u32x16; 8]) -> [u32x16; 8] {
// Index abc:xyzw, refers to a specific word in data as follows:
// abc - chunk index (in base 2)
// xyzw - word offset (in base 2)
// Transpose by applying 3 times the index permutation:
// abc:xyzw => bcx:yzwa
// In other words, rotate the index to the left by 1.
for _ in 0..3 {
let (d0, d1) = states[0].interleave(states[4]);
let (d2, d3) = states[1].interleave(states[5]);
let (d4, d5) = states[2].interleave(states[6]);
let (d6, d7) = states[3].interleave(states[7]);
states = [d0, d1, d2, d3, d4, d5, d6, d7];
}
states
}
/// Compresses 16 blake2s instances.
pub fn compress16(
h_vecs: [u32x16; 8],
msg_vecs: [u32x16; 16],
count_low: u32x16,
count_high: u32x16,
lastblock: u32x16,
lastnode: u32x16,
) -> [u32x16; 8] {
let mut v = [
h_vecs[0],
h_vecs[1],
h_vecs[2],
h_vecs[3],
h_vecs[4],
h_vecs[5],
h_vecs[6],
h_vecs[7],
u32x16::splat(IV[0]),
u32x16::splat(IV[1]),
u32x16::splat(IV[2]),
u32x16::splat(IV[3]),
u32x16::splat(IV[4]) ^ count_low,
u32x16::splat(IV[5]) ^ count_high,
u32x16::splat(IV[6]) ^ lastblock,
u32x16::splat(IV[7]) ^ lastnode,
];
round(&mut v, msg_vecs, 0);
round(&mut v, msg_vecs, 1);
round(&mut v, msg_vecs, 2);
round(&mut v, msg_vecs, 3);
round(&mut v, msg_vecs, 4);
round(&mut v, msg_vecs, 5);
round(&mut v, msg_vecs, 6);
round(&mut v, msg_vecs, 7);
round(&mut v, msg_vecs, 8);
round(&mut v, msg_vecs, 9);
[
h_vecs[0] ^ v[0] ^ v[8],
h_vecs[1] ^ v[1] ^ v[9],
h_vecs[2] ^ v[2] ^ v[10],
h_vecs[3] ^ v[3] ^ v[11],
h_vecs[4] ^ v[4] ^ v[12],
h_vecs[5] ^ v[5] ^ v[13],
h_vecs[6] ^ v[6] ^ v[14],
h_vecs[7] ^ v[7] ^ v[15],
]
}
#[cfg(test)]
mod tests {
use std::array;
use std::mem::transmute;
use std::simd::u32x16;
use aligned::{Aligned, A64};
use super::{compress16, transpose_msgs, untranspose_states};
use crate::core::vcs::blake2s_ref::compress;
#[test]
fn compress16_works() {
let states: Aligned<A64, [[u32; 8]; 16]> =
Aligned(array::from_fn(|i| array::from_fn(|j| (i + j) as u32)));
let msgs: Aligned<A64, [[u32; 16]; 16]> =
Aligned(array::from_fn(|i| array::from_fn(|j| (i + j + 20) as u32)));
let count_low = 1;
let count_high = 2;
let lastblock = 3;
let lastnode = 4;
let res_unvectorized = array::from_fn(|i| {
compress(
states[i], msgs[i], count_low, count_high, lastblock, lastnode,
)
});
let res_vectorized: [[u32; 8]; 16] = unsafe {
transmute(untranspose_states(compress16(
transpose_states(transmute(states)),
transpose_msgs(transmute(msgs)),
u32x16::splat(count_low),
u32x16::splat(count_high),
u32x16::splat(lastblock),
u32x16::splat(lastnode),
)))
};
assert_eq!(res_vectorized, res_unvectorized);
}
#[test]
fn untranspose_states_is_transpose_states_inverse() {
let states = array::from_fn(|i| u32x16::from(array::from_fn(|j| (i + j) as u32)));
let transposed_states = transpose_states(states);
let untrasponsed_transposed_states = untranspose_states(transposed_states);
assert_eq!(untrasponsed_transposed_states, states)
}
/// Transposes states, from 8 packed words, to get 16 results, each of size 32B.
fn transpose_states(mut states: [u32x16; 8]) -> [u32x16; 8] {
// Index abc:xyzw, refers to a specific word in data as follows:
// abc - chunk index (in base 2)
// xyzw - word offset (in base 2)
// Transpose by applying 3 times the index permutation:
// abc:xyzw => wab:cxyz
// In other words, rotate the index to the right by 1.
for _ in 0..3 {
let (s0, s4) = states[0].deinterleave(states[1]);
let (s1, s5) = states[2].deinterleave(states[3]);
let (s2, s6) = states[4].deinterleave(states[5]);
let (s3, s7) = states[6].deinterleave(states[7]);
states = [s0, s1, s2, s3, s4, s5, s6, s7];
}
states
}
}

View File

@ -0,0 +1,436 @@
use std::iter::zip;
use std::mem::transmute;
use bytemuck::{cast_slice, Zeroable};
use num_traits::One;
use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE};
use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
use super::qm31::PackedSecureField;
use super::SimdBackend;
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::{Col, CpuBackend};
use crate::core::circle::{CirclePoint, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{Field, FieldExpOps};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
};
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold};
use crate::core::poly::BitReversedOrder;
impl SimdBackend {
// TODO(Ohad): optimize.
fn twiddle_at<F: Field>(mappings: &[F], mut index: usize) -> F {
debug_assert!(
(1 << mappings.len()) as usize >= index,
"Index out of bounds. mappings log len = {}, index = {index}",
mappings.len().ilog2()
);
let mut product = F::one();
for &num in mappings.iter() {
if index & 1 == 1 {
product *= num;
}
index >>= 1;
if index == 0 {
break;
}
}
product
}
// TODO(Ohad): consider moving this to to a more general place.
// Note: CACHED_FFT_LOG_SIZE is specific to the backend.
fn generate_evaluation_mappings<F: Field>(point: CirclePoint<F>, log_size: u32) -> Vec<F> {
// Mappings are the factors used to compute the evaluation twiddle.
// Every twiddle (i) is of the form (m[0])^b_0 * (m[1])^b_1 * ... * (m[log_size -
// 1])^b_log_size.
// Where (m)_j are the mappings, and b_i is the j'th bit of i.
let mut mappings = vec![point.y, point.x];
let mut x = point.x;
for _ in 2..log_size {
x = CirclePoint::double_x(x);
mappings.push(x);
}
// The caller function expects the mapping in natural order. i.e. (y,x,h(x),h(h(x)),...).
// If the polynomial is large, the fft does a transpose in the middle in a granularity of 16
// (avx512). The coefficients would then be in tranposed order of 16-sized chunks.
// i.e. (a_(n-15), a_(n-14), ..., a_(n-1), a_(n-31), ..., a_(n-16), a_(n-32), ...).
// To compute the twiddles in the correct order, we need to transpose the coprresponding
// 'transposed bits' in the mappings. The result order of the mappings would then be
// (y, x, h(x), h^2(x), h^(log_n-1)(x), h^(log_n-2)(x) ...). To avoid code
// complexity for now, we just reverse the mappings, transpose, then reverse back.
// TODO(Ohad): optimize. consider changing the caller to expect the mappings in
// reversed-tranposed order.
if log_size > CACHED_FFT_LOG_SIZE {
mappings.reverse();
let n = mappings.len();
let n0 = (n - LOG_N_LANES as usize) / 2;
let n1 = (n - LOG_N_LANES as usize + 1) / 2;
let (ab, c) = mappings.split_at_mut(n1);
let (a, _b) = ab.split_at_mut(n0);
// Swap content of a,c.
a.swap_with_slice(&mut c[0..n0]);
mappings.reverse();
}
mappings
}
// Generates twiddle steps for efficiently computing the twiddles.
// steps[i] = t_i/(t_0*t_1*...*t_i-1).
fn twiddle_steps<F: Field>(mappings: &[F]) -> Vec<F>
where
F: FieldExpOps,
{
let mut denominators: Vec<F> = vec![mappings[0]];
for i in 1..mappings.len() {
denominators.push(denominators[i - 1] * mappings[i]);
}
let mut denom_inverses = vec![F::zero(); denominators.len()];
F::batch_inverse(&denominators, &mut denom_inverses);
let mut steps = vec![mappings[0]];
mappings
.iter()
.skip(1)
.zip(denom_inverses.iter())
.for_each(|(&m, &d)| {
steps.push(m * d);
});
steps.push(F::one());
steps
}
// Advances the twiddle by multiplying it by the next step. e.g:
// If idx(t) = 0b100..1010 , then f(t) = t * step[0]
// If idx(t) = 0b100..0111 , then f(t) = t * step[3]
fn advance_twiddle<F: Field>(twiddle: F, steps: &[F], curr_idx: usize) -> F {
twiddle * steps[curr_idx.trailing_ones() as usize]
}
}
// TODO(spapini): Everything is returned in redundant representation, where values can also be P.
// Decide if and when it's ok and what to do if it's not.
impl PolyOps for SimdBackend {
// The twiddles type is i32, and not BaseField. This is because the fast AVX mul implementation
// requries one of the numbers to be shifted left by 1 bit. This is not a reduced
// representation of the field.
type Twiddles = Vec<u32>;
fn new_canonical_ordered(
coset: CanonicCoset,
values: Col<Self, BaseField>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
// TODO(spapini): Optimize.
let eval = CpuBackend::new_canonical_ordered(coset, values.into_cpu_vec());
CircleEvaluation::new(
eval.domain,
Col::<SimdBackend, BaseField>::from_iter(eval.values),
)
}
fn interpolate(
eval: CircleEvaluation<Self, BaseField, BitReversedOrder>,
twiddles: &TwiddleTree<Self>,
) -> CirclePoly<Self> {
let mut values = eval.values;
let log_size = values.length.ilog2();
let twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles);
// Safe because [PackedBaseField] is aligned on 64 bytes.
unsafe {
ifft::ifft(
transmute(values.data.as_mut_ptr()),
&twiddles,
log_size as usize,
);
}
// TODO(spapini): Fuse this multiplication / rotation.
let inv = PackedBaseField::broadcast(BaseField::from(eval.domain.size()).inverse());
values.data.iter_mut().for_each(|x| *x *= inv);
CirclePoly::new(values)
}
fn eval_at_point(poly: &CirclePoly<Self>, point: CirclePoint<SecureField>) -> SecureField {
// If the polynomial is small, fallback to evaluate directly.
// TODO(Ohad): it's possible to avoid falling back. Consider fixing.
if poly.log_size() <= 8 {
return slow_eval_at_point(poly, point);
}
let mappings = Self::generate_evaluation_mappings(point, poly.log_size());
// 8 lowest mappings produce the first 2^8 twiddles. Separate to optimize each calculation.
let (map_low, map_high) = mappings.split_at(4);
let twiddle_lows =
PackedSecureField::from_array(std::array::from_fn(|i| Self::twiddle_at(map_low, i)));
let (map_mid, map_high) = map_high.split_at(4);
let twiddle_mids =
PackedSecureField::from_array(std::array::from_fn(|i| Self::twiddle_at(map_mid, i)));
// Compute the high twiddle steps.
let twiddle_steps = Self::twiddle_steps(map_high);
// Every twiddle is a product of mappings that correspond to '1's in the bit representation
// of the current index. For every 2^n alligned chunk of 2^n elements, the twiddle
// array is the same, denoted twiddle_low. Use this to compute sums of (coeff *
// twiddle_high) mod 2^n, then multiply by twiddle_low, and sum to get the final result.
let mut sum = PackedSecureField::zeroed();
let mut twiddle_high = SecureField::one();
for (i, coeff_chunk) in poly.coeffs.data.array_chunks::<N_LANES>().enumerate() {
// For every chunk of 2 ^ 4 * 2 ^ 4 = 2 ^ 8 elements, the twiddle high is the same.
// Multiply it by every mid twiddle factor to get the factors for the current chunk.
let high_twiddle_factors =
(PackedSecureField::broadcast(twiddle_high) * twiddle_mids).to_array();
// Sum the coefficients multiplied by each corrseponsing twiddle. Result is effectivley
// an array[16] where the value at index 'i' is the sum of all coefficients at indices
// that are i mod 16.
for (&packed_coeffs, mid_twiddle) in zip(coeff_chunk, high_twiddle_factors) {
sum += PackedSecureField::broadcast(mid_twiddle) * packed_coeffs;
}
// Advance twiddle high.
twiddle_high = Self::advance_twiddle(twiddle_high, &twiddle_steps, i);
}
(sum * twiddle_lows).pointwise_sum()
}
fn extend(poly: &CirclePoly<Self>, log_size: u32) -> CirclePoly<Self> {
// TODO(spapini): Optimize or get rid of extend.
poly.evaluate(CanonicCoset::new(log_size).circle_domain())
.interpolate()
}
fn evaluate(
poly: &CirclePoly<Self>,
domain: CircleDomain,
twiddles: &TwiddleTree<Self>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
// TODO(spapini): Precompute twiddles.
// TODO(spapini): Handle small cases.
let log_size = domain.log_size();
let fft_log_size = poly.log_size();
assert!(
log_size >= fft_log_size,
"Can only evaluate on larger domains"
);
let twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles);
// Evaluate on a big domains by evaluating on several subdomains.
let log_subdomains = log_size - fft_log_size;
// Allocate the destination buffer without initializing.
let mut values = Vec::with_capacity(domain.size() >> LOG_N_LANES);
#[allow(clippy::uninit_vec)]
unsafe {
values.set_len(domain.size() >> LOG_N_LANES)
};
for i in 0..(1 << log_subdomains) {
// The subdomain twiddles are a slice of the large domain twiddles.
let subdomain_twiddles = (0..(fft_log_size - 1))
.map(|layer_i| {
&twiddles[layer_i as usize]
[i << (fft_log_size - 2 - layer_i)..(i + 1) << (fft_log_size - 2 - layer_i)]
})
.collect::<Vec<_>>();
// FFT from the coefficients buffer to the values chunk.
unsafe {
rfft::fft(
transmute(poly.coeffs.data.as_ptr()),
transmute(
values[i << (fft_log_size - LOG_N_LANES)
..(i + 1) << (fft_log_size - LOG_N_LANES)]
.as_mut_ptr(),
),
&subdomain_twiddles,
fft_log_size as usize,
);
}
}
CircleEvaluation::new(
domain,
BaseColumn {
data: values,
length: domain.size(),
},
)
}
fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {
let mut twiddles = Vec::with_capacity(coset.size());
let mut itwiddles = Vec::with_capacity(coset.size());
// TODO(spapini): Optimize.
for layer in &rfft::get_twiddle_dbls(coset) {
twiddles.extend(layer);
}
// Pad by any value, to make the size a power of 2.
twiddles.push(1);
assert_eq!(twiddles.len(), coset.size());
for layer in &ifft::get_itwiddle_dbls(coset) {
itwiddles.extend(layer);
}
// Pad by any value, to make the size a power of 2.
itwiddles.push(1);
assert_eq!(itwiddles.len(), coset.size());
TwiddleTree {
root_coset: coset,
twiddles,
itwiddles,
}
}
}
fn slow_eval_at_point(
poly: &CirclePoly<SimdBackend>,
point: CirclePoint<SecureField>,
) -> SecureField {
let mut mappings = vec![point.y, point.x];
let mut x = point.x;
for _ in 2..poly.log_size() {
x = CirclePoint::double_x(x);
mappings.push(x);
}
mappings.reverse();
// If the polynomial is large, the fft does a transpose in the middle.
if poly.log_size() > CACHED_FFT_LOG_SIZE {
let n = mappings.len();
let n0 = (n - LOG_N_LANES as usize) / 2;
let n1 = (n - LOG_N_LANES as usize + 1) / 2;
let (ab, c) = mappings.split_at_mut(n1);
let (a, _b) = ab.split_at_mut(n0);
// Swap content of a,c.
a.swap_with_slice(&mut c[0..n0]);
}
fold(cast_slice::<_, BaseField>(&poly.coeffs.data), &mappings)
}
#[cfg(test)]
mod tests {
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use crate::core::backend::simd::circle::slow_eval_at_point;
use crate::core::backend::simd::fft::{CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::Column;
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, PolyOps};
use crate::core::poly::{BitReversedOrder, NaturalOrder};
#[test]
fn test_interpolate_and_eval() {
for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 4 {
let domain = CanonicCoset::new(log_size).circle_domain();
let evaluation = CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(
domain,
(0..1 << log_size).map(BaseField::from).collect(),
);
let poly = evaluation.clone().interpolate();
let evaluation2 = poly.evaluate(domain);
assert_eq!(evaluation.values.to_cpu(), evaluation2.values.to_cpu());
}
}
#[test]
fn test_eval_extension() {
for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 2 {
let domain = CanonicCoset::new(log_size).circle_domain();
let domain_ext = CanonicCoset::new(log_size + 2).circle_domain();
let evaluation = CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(
domain,
(0..1 << log_size).map(BaseField::from).collect(),
);
let poly = evaluation.clone().interpolate();
let evaluation2 = poly.evaluate(domain_ext);
assert_eq!(
poly.extend(log_size + 2).coeffs.to_cpu(),
evaluation2.interpolate().coeffs.to_cpu()
);
}
}
#[test]
fn test_eval_at_point() {
for log_size in MIN_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 4 {
let domain = CanonicCoset::new(log_size).circle_domain();
let evaluation = CircleEvaluation::<SimdBackend, BaseField, NaturalOrder>::new(
domain,
(0..1 << log_size).map(BaseField::from).collect(),
);
let poly = evaluation.bit_reverse().interpolate();
for i in [0, 1, 3, 1 << (log_size - 1), 1 << (log_size - 2)] {
let p = domain.at(i);
let eval = poly.eval_at_point(p.into_ef());
assert_eq!(
eval,
BaseField::from(i).into(),
"log_size={log_size}, i={i}"
);
}
}
}
#[test]
fn test_circle_poly_extend() {
for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 2 {
let poly =
CirclePoly::<SimdBackend>::new((0..1 << log_size).map(BaseField::from).collect());
let eval0 = poly.evaluate(CanonicCoset::new(log_size + 2).circle_domain());
let eval1 = poly
.extend(log_size + 2)
.evaluate(CanonicCoset::new(log_size + 2).circle_domain());
assert_eq!(eval0.values.to_cpu(), eval1.values.to_cpu());
}
}
#[test]
fn test_eval_securefield() {
let mut rng = SmallRng::seed_from_u64(0);
for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 2 {
let domain = CanonicCoset::new(log_size).circle_domain();
let evaluation = CircleEvaluation::<SimdBackend, BaseField, NaturalOrder>::new(
domain,
(0..1 << log_size).map(BaseField::from).collect(),
);
let poly = evaluation.bit_reverse().interpolate();
let x = rng.gen();
let y = rng.gen();
let p = CirclePoint { x, y };
let eval = PolyOps::eval_at_point(&poly, p);
assert_eq!(eval, slow_eval_at_point(&poly, p), "log_size = {log_size}");
}
}
}

View File

@ -0,0 +1,230 @@
use std::array;
use std::ops::{Add, Mul, MulAssign, Neg, Sub};
use bytemuck::{Pod, Zeroable};
use num_traits::{One, Zero};
use super::m31::{PackedM31, N_LANES};
use crate::core::fields::cm31::CM31;
use crate::core::fields::FieldExpOps;
/// SIMD implementation of [`CM31`].
#[derive(Copy, Clone, Debug)]
pub struct PackedCM31(pub [PackedM31; 2]);
impl PackedCM31 {
/// Constructs a new instance with all vector elements set to `value`.
pub fn broadcast(value: CM31) -> Self {
Self([PackedM31::broadcast(value.0), PackedM31::broadcast(value.1)])
}
/// Returns all `a` values such that each vector element is represented as `a + bi`.
pub fn a(&self) -> PackedM31 {
self.0[0]
}
/// Returns all `b` values such that each vector element is represented as `a + bi`.
pub fn b(&self) -> PackedM31 {
self.0[1]
}
pub fn to_array(&self) -> [CM31; N_LANES] {
let a = self.a().to_array();
let b = self.b().to_array();
array::from_fn(|i| CM31(a[i], b[i]))
}
pub fn from_array(values: [CM31; N_LANES]) -> Self {
Self([
PackedM31::from_array(values.map(|v| v.0)),
PackedM31::from_array(values.map(|v| v.1)),
])
}
/// Interleaves two vectors.
pub fn interleave(self, other: Self) -> (Self, Self) {
let Self([a_evens, b_evens]) = self;
let Self([a_odds, b_odds]) = other;
let (a_lhs, a_rhs) = a_evens.interleave(a_odds);
let (b_lhs, b_rhs) = b_evens.interleave(b_odds);
(Self([a_lhs, b_lhs]), Self([a_rhs, b_rhs]))
}
/// Deinterleaves two vectors.
pub fn deinterleave(self, other: Self) -> (Self, Self) {
let Self([a_self, b_self]) = self;
let Self([a_other, b_other]) = other;
let (a_evens, a_odds) = a_self.deinterleave(a_other);
let (b_evens, b_odds) = b_self.deinterleave(b_other);
(Self([a_evens, b_evens]), Self([a_odds, b_odds]))
}
/// Doubles each element in the vector.
pub fn double(self) -> Self {
let Self([a, b]) = self;
Self([a.double(), b.double()])
}
}
impl Add for PackedCM31 {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
Self([self.a() + rhs.a(), self.b() + rhs.b()])
}
}
impl Sub for PackedCM31 {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
Self([self.a() - rhs.a(), self.b() - rhs.b()])
}
}
impl Mul for PackedCM31 {
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
// Compute using Karatsuba.
let ac = self.a() * rhs.a();
let bd = self.b() * rhs.b();
// Computes (a + b) * (c + d).
let ab_t_cd = (self.a() + self.b()) * (rhs.a() + rhs.b());
// (ac - bd) + (ad + bc)i.
Self([ac - bd, ab_t_cd - ac - bd])
}
}
impl Zero for PackedCM31 {
fn zero() -> Self {
Self([PackedM31::zero(), PackedM31::zero()])
}
fn is_zero(&self) -> bool {
self.a().is_zero() && self.b().is_zero()
}
}
unsafe impl Pod for PackedCM31 {}
unsafe impl Zeroable for PackedCM31 {
fn zeroed() -> Self {
unsafe { core::mem::zeroed() }
}
}
impl One for PackedCM31 {
fn one() -> Self {
Self([PackedM31::one(), PackedM31::zero()])
}
}
impl MulAssign for PackedCM31 {
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
impl FieldExpOps for PackedCM31 {
fn inverse(&self) -> Self {
assert!(!self.is_zero(), "0 has no inverse");
// 1 / (a + bi) = (a - bi) / (a^2 + b^2).
Self([self.a(), -self.b()]) * (self.a().square() + self.b().square()).inverse()
}
}
impl Add<PackedM31> for PackedCM31 {
type Output = Self;
fn add(self, rhs: PackedM31) -> Self::Output {
Self([self.a() + rhs, self.b()])
}
}
impl Sub<PackedM31> for PackedCM31 {
type Output = Self;
fn sub(self, rhs: PackedM31) -> Self::Output {
let Self([a, b]) = self;
Self([a - rhs, b])
}
}
impl Mul<PackedM31> for PackedCM31 {
type Output = Self;
fn mul(self, rhs: PackedM31) -> Self::Output {
let Self([a, b]) = self;
Self([a * rhs, b * rhs])
}
}
impl Neg for PackedCM31 {
type Output = Self;
fn neg(self) -> Self::Output {
let Self([a, b]) = self;
Self([-a, -b])
}
}
#[cfg(test)]
mod tests {
use std::array;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use crate::core::backend::simd::cm31::PackedCM31;
#[test]
fn addition_works() {
let mut rng = SmallRng::seed_from_u64(0);
let lhs = rng.gen();
let rhs = rng.gen();
let packed_lhs = PackedCM31::from_array(lhs);
let packed_rhs = PackedCM31::from_array(rhs);
let res = packed_lhs + packed_rhs;
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i]));
}
#[test]
fn subtraction_works() {
let mut rng = SmallRng::seed_from_u64(0);
let lhs = rng.gen();
let rhs = rng.gen();
let packed_lhs = PackedCM31::from_array(lhs);
let packed_rhs = PackedCM31::from_array(rhs);
let res = packed_lhs - packed_rhs;
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i]));
}
#[test]
fn multiplication_works() {
let mut rng = SmallRng::seed_from_u64(0);
let lhs = rng.gen();
let rhs = rng.gen();
let packed_lhs = PackedCM31::from_array(lhs);
let packed_rhs = PackedCM31::from_array(rhs);
let res = packed_lhs * packed_rhs;
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i]));
}
#[test]
fn negation_works() {
let mut rng = SmallRng::seed_from_u64(0);
let values = rng.gen();
let packed_values = PackedCM31::from_array(values);
let res = -packed_values;
assert_eq!(res.to_array(), values.map(|v| -v));
}
}

View File

@ -0,0 +1,656 @@
use std::iter::zip;
use std::{array, mem};
use bytemuck::allocation::cast_vec;
use bytemuck::{cast_slice, cast_slice_mut, Zeroable};
use itertools::{izip, Itertools};
use num_traits::Zero;
use super::cm31::PackedCM31;
use super::m31::{PackedBaseField, N_LANES};
use super::qm31::{PackedQM31, PackedSecureField};
use super::very_packed_m31::{VeryPackedBaseField, VeryPackedSecureField, N_VERY_PACKED_ELEMS};
use super::SimdBackend;
use crate::core::backend::{Column, CpuBackend};
use crate::core::fields::cm31::CM31;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE};
use crate::core::fields::{FieldExpOps, FieldOps};
impl FieldOps<BaseField> for SimdBackend {
fn batch_inverse(column: &BaseColumn, dst: &mut BaseColumn) {
PackedBaseField::batch_inverse(&column.data, &mut dst.data);
}
}
impl FieldOps<SecureField> for SimdBackend {
fn batch_inverse(column: &SecureColumn, dst: &mut SecureColumn) {
PackedSecureField::batch_inverse(&column.data, &mut dst.data);
}
}
/// An efficient structure for storing and operating on a arbitrary number of [`BaseField`] values.
#[derive(Clone, Debug)]
pub struct BaseColumn {
pub data: Vec<PackedBaseField>,
/// The number of [`BaseField`]s in the vector.
pub length: usize,
}
impl BaseColumn {
/// Extracts a slice containing the entire vector of [`BaseField`]s.
pub fn as_slice(&self) -> &[BaseField] {
&cast_slice(&self.data)[..self.length]
}
/// Extracts a mutable slice containing the entire vector of [`BaseField`]s.
pub fn as_mut_slice(&mut self) -> &mut [BaseField] {
&mut cast_slice_mut(&mut self.data)[..self.length]
}
pub fn into_cpu_vec(mut self) -> Vec<BaseField> {
let capacity = self.data.capacity() * N_LANES;
let length = self.length;
let ptr = self.data.as_mut_ptr() as *mut BaseField;
let res = unsafe { Vec::from_raw_parts(ptr, length, capacity) };
mem::forget(self);
res
}
/// Returns a vector of `BaseColumnMutSlice`s, each mutably owning
/// `chunk_size` `PackedBaseField`s (i.e, `chuck_size` * `N_LANES` elements).
pub fn chunks_mut(&mut self, chunk_size: usize) -> Vec<BaseColumnMutSlice<'_>> {
self.data
.chunks_mut(chunk_size)
.map(BaseColumnMutSlice)
.collect_vec()
}
pub fn into_secure_column(self) -> SecureColumn {
let length = self.len();
let data = self.data.into_iter().map(PackedSecureField::from).collect();
SecureColumn { data, length }
}
}
impl Column<BaseField> for BaseColumn {
fn zeros(length: usize) -> Self {
let data = vec![PackedBaseField::zeroed(); length.div_ceil(N_LANES)];
Self { data, length }
}
#[allow(clippy::uninit_vec)]
unsafe fn uninitialized(length: usize) -> Self {
let mut data = Vec::with_capacity(length.div_ceil(N_LANES));
data.set_len(length.div_ceil(N_LANES));
Self { data, length }
}
fn to_cpu(&self) -> Vec<BaseField> {
self.as_slice().to_vec()
}
fn len(&self) -> usize {
self.length
}
fn at(&self, index: usize) -> BaseField {
self.data[index / N_LANES].to_array()[index % N_LANES]
}
fn set(&mut self, index: usize, value: BaseField) {
let mut packed = self.data[index / N_LANES].to_array();
packed[index % N_LANES] = value;
self.data[index / N_LANES] = PackedBaseField::from_array(packed)
}
}
impl FromIterator<BaseField> for BaseColumn {
fn from_iter<I: IntoIterator<Item = BaseField>>(iter: I) -> Self {
let mut chunks = iter.into_iter().array_chunks();
let mut data = (&mut chunks).map(PackedBaseField::from_array).collect_vec();
let mut length = data.len() * N_LANES;
if let Some(remainder) = chunks.into_remainder() {
if !remainder.is_empty() {
length += remainder.len();
let mut last = [BaseField::zero(); N_LANES];
last[..remainder.len()].copy_from_slice(remainder.as_slice());
data.push(PackedBaseField::from_array(last));
}
}
Self { data, length }
}
}
// A efficient structure for storing and operating on a arbitrary number of [`SecureField`] values.
#[derive(Clone, Debug)]
pub struct CM31Column {
pub data: Vec<PackedCM31>,
pub length: usize,
}
impl Column<CM31> for CM31Column {
fn zeros(length: usize) -> Self {
Self {
data: vec![PackedCM31::zeroed(); length.div_ceil(N_LANES)],
length,
}
}
#[allow(clippy::uninit_vec)]
unsafe fn uninitialized(length: usize) -> Self {
let mut data = Vec::with_capacity(length.div_ceil(N_LANES));
data.set_len(length.div_ceil(N_LANES));
Self { data, length }
}
fn to_cpu(&self) -> Vec<CM31> {
self.data
.iter()
.flat_map(|x| x.to_array())
.take(self.length)
.collect()
}
fn len(&self) -> usize {
self.length
}
fn at(&self, index: usize) -> CM31 {
self.data[index / N_LANES].to_array()[index % N_LANES]
}
fn set(&mut self, index: usize, value: CM31) {
let mut packed = self.data[index / N_LANES].to_array();
packed[index % N_LANES] = value;
self.data[index / N_LANES] = PackedCM31::from_array(packed)
}
}
impl FromIterator<CM31> for CM31Column {
fn from_iter<I: IntoIterator<Item = CM31>>(iter: I) -> Self {
let mut chunks = iter.into_iter().array_chunks();
let mut data = (&mut chunks).map(PackedCM31::from_array).collect_vec();
let mut length = data.len() * N_LANES;
if let Some(remainder) = chunks.into_remainder() {
if !remainder.is_empty() {
length += remainder.len();
let mut last = [CM31::zero(); N_LANES];
last[..remainder.len()].copy_from_slice(remainder.as_slice());
data.push(PackedCM31::from_array(last));
}
}
Self { data, length }
}
}
impl FromIterator<PackedCM31> for CM31Column {
fn from_iter<I: IntoIterator<Item = PackedCM31>>(iter: I) -> Self {
let data = (&mut iter.into_iter()).collect_vec();
let length = data.len() * N_LANES;
Self { data, length }
}
}
/// A mutable slice of a BaseColumn.
pub struct BaseColumnMutSlice<'a>(pub &'a mut [PackedBaseField]);
impl<'a> BaseColumnMutSlice<'a> {
pub fn at(&self, index: usize) -> BaseField {
self.0[index / N_LANES].to_array()[index % N_LANES]
}
pub fn set(&mut self, index: usize, value: BaseField) {
let mut packed = self.0[index / N_LANES].to_array();
packed[index % N_LANES] = value;
self.0[index / N_LANES] = PackedBaseField::from_array(packed)
}
}
/// An efficient structure for storing and operating on a arbitrary number of [`SecureField`]
/// values.
#[derive(Clone, Debug)]
pub struct SecureColumn {
pub data: Vec<PackedSecureField>,
/// The number of [`SecureField`]s in the vector.
pub length: usize,
}
impl SecureColumn {
// Separates a single column of `PackedSecureField` elements into `SECURE_EXTENSION_DEGREE` many
// `PackedBaseField` coordinate columns.
pub fn into_secure_column_by_coords(self) -> SecureColumnByCoords<SimdBackend> {
if self.len() < N_LANES {
return self.to_cpu().into_iter().collect();
}
let length = self.length;
let packed_length = self.data.len();
let mut columns = array::from_fn(|_| Vec::with_capacity(packed_length));
for v in self.data {
let packed_coords = v.into_packed_m31s();
zip(&mut columns, packed_coords).for_each(|(col, packed_coord)| col.push(packed_coord));
}
SecureColumnByCoords {
columns: columns.map(|col| BaseColumn { data: col, length }),
}
}
}
impl Column<SecureField> for SecureColumn {
fn zeros(length: usize) -> Self {
Self {
data: vec![PackedSecureField::zeroed(); length.div_ceil(N_LANES)],
length,
}
}
#[allow(clippy::uninit_vec)]
unsafe fn uninitialized(length: usize) -> Self {
let mut data = Vec::with_capacity(length.div_ceil(N_LANES));
data.set_len(length.div_ceil(N_LANES));
Self { data, length }
}
fn to_cpu(&self) -> Vec<SecureField> {
self.data
.iter()
.flat_map(|x| x.to_array())
.take(self.length)
.collect()
}
fn len(&self) -> usize {
self.length
}
fn at(&self, index: usize) -> SecureField {
self.data[index / N_LANES].to_array()[index % N_LANES]
}
fn set(&mut self, index: usize, value: SecureField) {
let mut packed = self.data[index / N_LANES].to_array();
packed[index % N_LANES] = value;
self.data[index / N_LANES] = PackedSecureField::from_array(packed)
}
}
impl FromIterator<SecureField> for SecureColumn {
fn from_iter<I: IntoIterator<Item = SecureField>>(iter: I) -> Self {
let mut chunks = iter.into_iter().array_chunks();
let mut data = (&mut chunks)
.map(PackedSecureField::from_array)
.collect_vec();
let mut length = data.len() * N_LANES;
if let Some(remainder) = chunks.into_remainder() {
if !remainder.is_empty() {
length += remainder.len();
let mut last = [SecureField::zero(); N_LANES];
last[..remainder.len()].copy_from_slice(remainder.as_slice());
data.push(PackedSecureField::from_array(last));
}
}
Self { data, length }
}
}
impl FromIterator<PackedSecureField> for SecureColumn {
fn from_iter<I: IntoIterator<Item = PackedSecureField>>(iter: I) -> Self {
let data = iter.into_iter().collect_vec();
let length = data.len() * N_LANES;
Self { data, length }
}
}
/// A mutable slice of a SecureColumnByCoords.
pub struct SecureColumnByCoordsMutSlice<'a>(pub [BaseColumnMutSlice<'a>; SECURE_EXTENSION_DEGREE]);
impl<'a> SecureColumnByCoordsMutSlice<'a> {
/// # Safety
///
/// `vec_index` must be a valid index.
pub unsafe fn packed_at(&self, vec_index: usize) -> PackedSecureField {
PackedQM31([
PackedCM31([
*self.0[0].0.get_unchecked(vec_index),
*self.0[1].0.get_unchecked(vec_index),
]),
PackedCM31([
*self.0[2].0.get_unchecked(vec_index),
*self.0[3].0.get_unchecked(vec_index),
]),
])
}
/// # Safety
///
/// `vec_index` must be a valid index.
pub unsafe fn set_packed(&mut self, vec_index: usize, value: PackedSecureField) {
let PackedQM31([PackedCM31([a, b]), PackedCM31([c, d])]) = value;
*self.0[0].0.get_unchecked_mut(vec_index) = a;
*self.0[1].0.get_unchecked_mut(vec_index) = b;
*self.0[2].0.get_unchecked_mut(vec_index) = c;
*self.0[3].0.get_unchecked_mut(vec_index) = d;
}
}
impl SecureColumnByCoords<SimdBackend> {
pub fn packed_len(&self) -> usize {
self.columns[0].data.len()
}
/// # Safety
///
/// `vec_index` must be a valid index.
pub unsafe fn packed_at(&self, vec_index: usize) -> PackedSecureField {
PackedQM31([
PackedCM31([
*self.columns[0].data.get_unchecked(vec_index),
*self.columns[1].data.get_unchecked(vec_index),
]),
PackedCM31([
*self.columns[2].data.get_unchecked(vec_index),
*self.columns[3].data.get_unchecked(vec_index),
]),
])
}
/// # Safety
///
/// `vec_index` must be a valid index.
pub unsafe fn set_packed(&mut self, vec_index: usize, value: PackedSecureField) {
let PackedQM31([PackedCM31([a, b]), PackedCM31([c, d])]) = value;
*self.columns[0].data.get_unchecked_mut(vec_index) = a;
*self.columns[1].data.get_unchecked_mut(vec_index) = b;
*self.columns[2].data.get_unchecked_mut(vec_index) = c;
*self.columns[3].data.get_unchecked_mut(vec_index) = d;
}
pub fn to_vec(&self) -> Vec<SecureField> {
izip!(
self.columns[0].to_cpu(),
self.columns[1].to_cpu(),
self.columns[2].to_cpu(),
self.columns[3].to_cpu(),
)
.map(|(a, b, c, d)| SecureField::from_m31_array([a, b, c, d]))
.collect()
}
/// Returns a vector of `SecureColumnByCoordsMutSlice`s, each mutably owning
/// `SECURE_EXTENSION_DEGREE` slices of `chunk_size` `PackedBaseField`s
/// (i.e, `chuck_size` * `N_LANES` secure field elements, by coordinates).
pub fn chunks_mut(&mut self, chunk_size: usize) -> Vec<SecureColumnByCoordsMutSlice<'_>> {
let [a, b, c, d] = self
.columns
.get_many_mut([0, 1, 2, 3])
.unwrap()
.map(|x| x.chunks_mut(chunk_size));
izip!(a, b, c, d)
.map(|(a, b, c, d)| SecureColumnByCoordsMutSlice([a, b, c, d]))
.collect_vec()
}
}
impl FromIterator<SecureField> for SecureColumnByCoords<SimdBackend> {
fn from_iter<I: IntoIterator<Item = SecureField>>(iter: I) -> Self {
let cpu_col = SecureColumnByCoords::<CpuBackend>::from_iter(iter);
let columns = cpu_col.columns.map(|col| col.into_iter().collect());
SecureColumnByCoords { columns }
}
}
#[derive(Clone, Debug)]
pub struct VeryPackedBaseColumn {
pub data: Vec<VeryPackedBaseField>,
/// The number of [`BaseField`]s in the vector.
pub length: usize,
}
impl VeryPackedBaseColumn {
/// Transforms a `&BaseColumn` to a `&VeryPackedBaseColumn`.
/// # Safety
///
/// The resulting pointer does not update the underlying `data`'s length.
pub unsafe fn transform_under_ref(value: &BaseColumn) -> &Self {
&*(std::ptr::addr_of!(*value) as *const VeryPackedBaseColumn)
}
}
impl From<BaseColumn> for VeryPackedBaseColumn {
fn from(value: BaseColumn) -> Self {
Self {
data: cast_vec(value.data),
length: value.length,
}
}
}
impl From<VeryPackedBaseColumn> for BaseColumn {
fn from(value: VeryPackedBaseColumn) -> Self {
Self {
data: cast_vec(value.data),
length: value.length,
}
}
}
impl FromIterator<BaseField> for VeryPackedBaseColumn {
fn from_iter<I: IntoIterator<Item = BaseField>>(iter: I) -> Self {
BaseColumn::from_iter(iter).into()
}
}
impl Column<BaseField> for VeryPackedBaseColumn {
fn zeros(length: usize) -> Self {
BaseColumn::zeros(length).into()
}
#[allow(clippy::uninit_vec)]
unsafe fn uninitialized(length: usize) -> Self {
BaseColumn::uninitialized(length).into()
}
fn to_cpu(&self) -> Vec<BaseField> {
todo!()
}
fn len(&self) -> usize {
self.length
}
fn at(&self, index: usize) -> BaseField {
let chunk_size = N_LANES * N_VERY_PACKED_ELEMS;
self.data[index / chunk_size].to_array()[index % chunk_size]
}
fn set(&mut self, index: usize, value: BaseField) {
let chunk_size = N_LANES * N_VERY_PACKED_ELEMS;
let mut packed = self.data[index / chunk_size].to_array();
packed[index % chunk_size] = value;
self.data[index / chunk_size] = VeryPackedBaseField::from_array(packed)
}
}
#[derive(Clone, Debug)]
pub struct VeryPackedSecureColumnByCoords {
pub columns: [VeryPackedBaseColumn; SECURE_EXTENSION_DEGREE],
}
impl From<SecureColumnByCoords<SimdBackend>> for VeryPackedSecureColumnByCoords {
fn from(value: SecureColumnByCoords<SimdBackend>) -> Self {
Self {
columns: value
.columns
.into_iter()
.map(VeryPackedBaseColumn::from)
.collect_vec()
.try_into()
.unwrap(),
}
}
}
impl From<VeryPackedSecureColumnByCoords> for SecureColumnByCoords<SimdBackend> {
fn from(value: VeryPackedSecureColumnByCoords) -> Self {
Self {
columns: value
.columns
.into_iter()
.map(BaseColumn::from)
.collect_vec()
.try_into()
.unwrap(),
}
}
}
impl VeryPackedSecureColumnByCoords {
pub fn packed_len(&self) -> usize {
self.columns[0].data.len()
}
/// # Safety
///
/// `vec_index` must be a valid index.
pub unsafe fn packed_at(&self, vec_index: usize) -> VeryPackedSecureField {
VeryPackedSecureField::from_fn(|i| {
PackedQM31([
PackedCM31([
self.columns[0].data.get_unchecked(vec_index).0[i],
self.columns[1].data.get_unchecked(vec_index).0[i],
]),
PackedCM31([
self.columns[2].data.get_unchecked(vec_index).0[i],
self.columns[3].data.get_unchecked(vec_index).0[i],
]),
])
})
}
/// # Safety
///
/// `vec_index` must be a valid index.
pub unsafe fn set_packed(&mut self, vec_index: usize, value: VeryPackedSecureField) {
for i in 0..N_VERY_PACKED_ELEMS {
let PackedQM31([PackedCM31([a, b]), PackedCM31([c, d])]) = value.0[i];
self.columns[0].data.get_unchecked_mut(vec_index).0[i] = a;
self.columns[1].data.get_unchecked_mut(vec_index).0[i] = b;
self.columns[2].data.get_unchecked_mut(vec_index).0[i] = c;
self.columns[3].data.get_unchecked_mut(vec_index).0[i] = d;
}
}
pub fn to_vec(&self) -> Vec<SecureField> {
izip!(
self.columns[0].to_cpu(),
self.columns[1].to_cpu(),
self.columns[2].to_cpu(),
self.columns[3].to_cpu(),
)
.map(|(a, b, c, d)| SecureField::from_m31_array([a, b, c, d]))
.collect()
}
/// Transforms a `&mut SecureColumnByCoords<SimdBackend>` to a
/// `&mut VeryPackedSecureColumnByCoords`.
///
/// # Safety
///
/// The resulting pointer does not update the underlying columns' `data`'s lengths.
pub unsafe fn transform_under_mut(value: &mut SecureColumnByCoords<SimdBackend>) -> &mut Self {
&mut *(std::ptr::addr_of!(*value) as *mut VeryPackedSecureColumnByCoords)
}
}
#[cfg(test)]
mod tests {
use std::array;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use super::BaseColumn;
use crate::core::backend::simd::column::SecureColumn;
use crate::core::backend::simd::m31::N_LANES;
use crate::core::backend::simd::qm31::PackedQM31;
use crate::core::backend::Column;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
#[test]
fn base_field_vec_from_iter_works() {
let values: [BaseField; 30] = array::from_fn(BaseField::from);
let res = values.into_iter().collect::<BaseColumn>();
assert_eq!(res.to_cpu(), values);
}
#[test]
fn secure_field_vec_from_iter_works() {
let mut rng = SmallRng::seed_from_u64(0);
let values: [SecureField; 30] = rng.gen();
let res = values.into_iter().collect::<SecureColumn>();
assert_eq!(res.to_cpu(), values);
}
#[test]
fn test_base_column_chunks_mut() {
let values: [BaseField; N_LANES * 7] = array::from_fn(BaseField::from);
let mut col = values.into_iter().collect::<BaseColumn>();
const CHUNK_SIZE: usize = 2;
let mut chunks = col.chunks_mut(CHUNK_SIZE);
chunks[2].set(19, BaseField::from(1234));
chunks[3].set(1, BaseField::from(5678));
assert_eq!(col.at(2 * CHUNK_SIZE * N_LANES + 19), BaseField::from(1234));
assert_eq!(col.at(3 * CHUNK_SIZE * N_LANES + 1), BaseField::from(5678));
}
#[test]
fn test_secure_column_by_coords_chunks_mut() {
const COL_PACKED_SIZE: usize = 16;
let a: [BaseField; N_LANES * COL_PACKED_SIZE] = array::from_fn(BaseField::from);
let b: [BaseField; N_LANES * COL_PACKED_SIZE] = array::from_fn(BaseField::from);
let c: [BaseField; N_LANES * COL_PACKED_SIZE] = array::from_fn(BaseField::from);
let d: [BaseField; N_LANES * COL_PACKED_SIZE] = array::from_fn(BaseField::from);
let mut col = SecureColumnByCoords {
columns: [a, b, c, d].map(|values| values.into_iter().collect::<BaseColumn>()),
};
let mut rng = SmallRng::seed_from_u64(0);
let rand0 = PackedQM31::from_array(rng.gen());
let rand1 = PackedQM31::from_array(rng.gen());
const CHUNK_SIZE: usize = 4;
let mut chunks = col.chunks_mut(CHUNK_SIZE);
unsafe {
chunks[2].set_packed(3, rand0);
chunks[3].set_packed(1, rand1);
assert_eq!(
col.packed_at(2 * CHUNK_SIZE + 3).to_array(),
rand0.to_array()
);
assert_eq!(
col.packed_at(3 * CHUNK_SIZE + 1).to_array(),
rand1.to_array()
);
}
}
}

View File

@ -0,0 +1,86 @@
use std::simd::{simd_swizzle, u32x2, Simd};
use super::m31::{PackedM31, LOG_N_LANES};
use crate::core::circle::{CirclePoint, M31_CIRCLE_LOG_ORDER};
use crate::core::fields::m31::M31;
use crate::core::poly::circle::CircleDomain;
use crate::core::utils::bit_reverse_index;
pub struct CircleDomainBitRevIterator {
domain: CircleDomain,
i: usize,
current: CirclePoint<PackedM31>,
flips: [CirclePoint<M31>; (M31_CIRCLE_LOG_ORDER - LOG_N_LANES) as usize],
}
impl CircleDomainBitRevIterator {
pub fn new(domain: CircleDomain) -> Self {
let log_size = domain.log_size();
assert!(log_size >= LOG_N_LANES);
let initial_points = std::array::from_fn(|i| domain.at(bit_reverse_index(i, log_size)));
let current = CirclePoint {
x: PackedM31::from_array(initial_points.each_ref().map(|p| p.x)),
y: PackedM31::from_array(initial_points.each_ref().map(|p| p.y)),
};
let mut flips = [CirclePoint::zero(); (M31_CIRCLE_LOG_ORDER - LOG_N_LANES) as usize];
for i in 0..(log_size - LOG_N_LANES) {
// L i
// 000111000000 ->
// 000000100000
let prev_mul = bit_reverse_index((1 << i) - 1, log_size - LOG_N_LANES);
let new_mul = bit_reverse_index(1 << i, log_size - LOG_N_LANES);
let flip = domain.half_coset.step.mul(new_mul as u128)
- domain.half_coset.step.mul(prev_mul as u128);
flips[i as usize] = flip;
}
Self {
domain,
i: 0,
current,
flips,
}
}
}
impl Iterator for CircleDomainBitRevIterator {
type Item = CirclePoint<PackedM31>;
fn next(&mut self) -> Option<Self::Item> {
if self.i << LOG_N_LANES >= self.domain.size() {
return None;
}
let res = self.current;
let flip = self.flips[self.i.trailing_ones() as usize];
let flipx = Simd::splat(flip.x.0);
let flipy = u32x2::from_array([flip.y.0, (-flip.y).0]);
let flipy = simd_swizzle!(flipy, [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]);
let flip = unsafe {
CirclePoint {
x: PackedM31::from_simd_unchecked(flipx),
y: PackedM31::from_simd_unchecked(flipy),
}
};
self.current = self.current + flip;
self.i += 1;
Some(res)
}
}
#[test]
fn test_circle_domain_bit_rev_iterator() {
let domain = CircleDomain::new(crate::core::circle::Coset::new(
crate::core::circle::CirclePointIndex::generator(),
5,
));
let mut expected = domain.iter().collect::<Vec<_>>();
crate::core::utils::bit_reverse(&mut expected);
let actual = CircleDomainBitRevIterator::new(domain)
.flat_map(|c| -> [_; 16] {
std::array::from_fn(|i| CirclePoint {
x: c.x.to_array()[i],
y: c.y.to_array()[i],
})
})
.collect::<Vec<_>>();
assert_eq!(actual, expected);
}

View File

@ -0,0 +1,712 @@
//! Inverse fft.
use std::simd::{simd_swizzle, u32x16, u32x2, u32x4};
use itertools::Itertools;
use super::{
compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
};
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::circle::Coset;
use crate::core::fields::FieldExpOps;
use crate::core::utils::bit_reverse;
/// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values.
///
/// # Arguments
///
/// - `values`: A mutable pointer to the values on which the ICFFT is to be performed.
/// - `twiddle_dbl`: A reference to the doubles of the twiddle factors.
/// - `log_n_elements`: The log of the number of elements in the `values` array.
///
/// # Panics
///
/// Panic if `log_n_elements` is less than [`MIN_FFT_LOG_SIZE`].
///
/// # Safety
///
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
pub unsafe fn ifft(values: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) {
assert!(log_n_elements >= MIN_FFT_LOG_SIZE as usize);
let log_n_vecs = log_n_elements - LOG_N_LANES as usize;
if log_n_elements <= CACHED_FFT_LOG_SIZE as usize {
ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements);
return;
}
let fft_layers_pre_transpose = log_n_vecs.div_ceil(2);
let fft_layers_post_transpose = log_n_vecs / 2;
ifft_lower_with_vecwise(
values,
&twiddle_dbl[..3 + fft_layers_pre_transpose],
log_n_elements,
fft_layers_pre_transpose + LOG_N_LANES as usize,
);
transpose_vecs(values, log_n_vecs);
ifft_lower_without_vecwise(
values,
&twiddle_dbl[3 + fft_layers_pre_transpose..],
log_n_elements,
fft_layers_post_transpose,
);
}
/// Computes partial ifft on `2^log_size` M31 elements.
///
/// # Arguments
///
/// - `values`: Pointer to the entire value array, aligned to 64 bytes.
/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the the ifft. Layer i
/// holds `2^(log_size - 1 - i)` twiddles.
/// - `log_size`: The log of the number of number of M31 elements in the array.
/// - `fft_layers`: The number of ifft layers to apply, out of log_size.
///
/// # Panics
///
/// Panics if `log_size` is not at least 5.
///
/// # Safety
///
/// `values` must have the same alignment as [`PackedBaseField`].
/// `fft_layers` must be at least 5.
pub unsafe fn ifft_lower_with_vecwise(
values: *mut u32,
twiddle_dbl: &[&[u32]],
log_size: usize,
fft_layers: usize,
) {
const VECWISE_FFT_BITS: usize = LOG_N_LANES as usize + 1;
assert!(log_size >= VECWISE_FFT_BITS);
assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2));
for index_h in 0..1 << (log_size - fft_layers) {
ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h);
for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) {
match fft_layers - layer {
1 => {
ifft1_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h);
}
2 => {
ifft2_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h);
}
_ => {
ifft3_loop(
values,
&twiddle_dbl[(layer - 1)..],
fft_layers - layer - 3,
layer,
index_h,
);
}
}
}
}
}
/// Computes partial ifft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of
/// the index).
///
/// # Arguments
///
/// - `values`: Pointer to the entire value array, aligned to 64 bytes.
/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the the ifft.
/// - `log_size`: The log of the number of number of M31 elements in the array.
/// - `fft_layers`: The number of ifft layers to apply, out of `log_size - LOG_N_LANES`.
///
/// # Panics
///
/// Panics if `log_size` is not at least 4.
///
/// # Safety
///
/// `values` must have the same alignment as [`PackedBaseField`].
/// `fft_layers` must be at least 4.
pub unsafe fn ifft_lower_without_vecwise(
values: *mut u32,
twiddle_dbl: &[&[u32]],
log_size: usize,
fft_layers: usize,
) {
assert!(log_size >= LOG_N_LANES as usize);
for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) {
for layer in (0..fft_layers).step_by(3) {
let fixed_layer = layer + LOG_N_LANES as usize;
match fft_layers - layer {
1 => {
ifft1_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h);
}
2 => {
ifft2_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h);
}
_ => {
ifft3_loop(
values,
&twiddle_dbl[layer..],
fft_layers - layer - 3,
fixed_layer,
index_h,
);
}
}
}
}
}
/// Runs the first 5 ifft layers across the entire array.
///
/// # Arguments
///
/// - `values`: Pointer to the entire value array, aligned to 64 bytes.
/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 5 ifft layers.
/// - `high_bits`: The number of bits this loops needs to run on.
/// - `index_h`: The higher part of the index, iterated by the caller.
///
/// # Safety
///
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
pub unsafe fn ifft_vecwise_loop(
values: *mut u32,
twiddle_dbl: &[&[u32]],
loop_bits: usize,
index_h: usize,
) {
for index_l in 0..1 << loop_bits {
let index = (index_h << loop_bits) + index_l;
let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const());
let mut val1 = PackedBaseField::load(values.add(index * 32 + 16).cast_const());
(val0, val1) = vecwise_ibutterflies(
val0,
val1,
std::array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)),
std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)),
std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)),
);
(val0, val1) = simd_ibutterfly(
val0,
val1,
u32x16::splat(*twiddle_dbl[3].get_unchecked(index)),
);
val0.store(values.add(index * 32));
val1.store(values.add(index * 32 + 16));
}
}
/// Runs 3 ifft layers across the entire array.
///
/// # Arguments
///
/// - `values`: Pointer to the entire value array, aligned to 64 bytes.
/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 3 ifft layers.
/// - `loop_bits`: The number of bits this loops needs to run on.
/// - `layer`: The layer number of the first ifft layer to apply. The layers `layer`, `layer + 1`,
/// `layer + 2` are applied.
/// - `index_h`: The higher part of the index, iterated by the caller.
///
/// # Safety
///
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
pub unsafe fn ifft3_loop(
values: *mut u32,
twiddle_dbl: &[&[u32]],
loop_bits: usize,
layer: usize,
index_h: usize,
) {
for index_l in 0..1 << loop_bits {
let index = (index_h << loop_bits) + index_l;
let offset = index << (layer + 3);
for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) {
ifft3(
values,
offset + l,
layer,
std::array::from_fn(|i| {
*twiddle_dbl[0].get_unchecked((index * 4 + i) & (twiddle_dbl[0].len() - 1))
}),
std::array::from_fn(|i| {
*twiddle_dbl[1].get_unchecked((index * 2 + i) & (twiddle_dbl[1].len() - 1))
}),
std::array::from_fn(|i| {
*twiddle_dbl[2].get_unchecked((index + i) & (twiddle_dbl[2].len() - 1))
}),
);
}
}
}
/// Runs 2 ifft layers across the entire array.
///
/// # Arguments
///
/// - `values`: Pointer to the entire value array, aligned to 64 bytes.
/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 2 ifft layers.
/// - `loop_bits`: The number of bits this loops needs to run on.
/// - `layer`: The layer number of the first ifft layer to apply. The layers `layer`, `layer + 1`
/// are applied.
/// - `index`: The index, iterated by the caller.
///
/// # Safety
///
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
unsafe fn ifft2_loop(values: *mut u32, twiddle_dbl: &[&[u32]], layer: usize, index: usize) {
let offset = index << (layer + 2);
for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) {
ifft2(
values,
offset + l,
layer,
std::array::from_fn(|i| {
*twiddle_dbl[0].get_unchecked((index * 2 + i) & (twiddle_dbl[0].len() - 1))
}),
std::array::from_fn(|i| {
*twiddle_dbl[1].get_unchecked((index + i) & (twiddle_dbl[1].len() - 1))
}),
);
}
}
/// Runs 1 ifft layer across the entire array.
///
/// # Arguments
///
/// - `values`: Pointer to the entire value array, aligned to 64 bytes.
/// - `twiddle_dbl`: The doubles of the twiddle factors for the ifft layer.
/// - `layer`: The layer number of the ifft layer to apply.
/// - `index_h`: The higher part of the index, iterated by the caller.
///
/// # Safety
///
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
unsafe fn ifft1_loop(values: *mut u32, twiddle_dbl: &[&[u32]], layer: usize, index: usize) {
let offset = index << (layer + 1);
for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) {
ifft1(
values,
offset + l,
layer,
std::array::from_fn(|i| {
*twiddle_dbl[0].get_unchecked((index + i) & (twiddle_dbl[0].len() - 1))
}),
);
}
}
/// Computes the ibutterfly operation for packed M31 elements.
///
/// Returns `val0 + val1, t (val0 - val1)`. `val0, val1` are packed M31 elements. 16 M31 words at
/// each. Each value is assumed to be in unreduced form, [0, P] including P. `twiddle_dbl` holds 16
/// values, each is a *double* of a twiddle factor, in unreduced form.
pub fn simd_ibutterfly(
val0: PackedBaseField,
val1: PackedBaseField,
twiddle_dbl: u32x16,
) -> (PackedBaseField, PackedBaseField) {
let r0 = val0 + val1;
let r1 = val0 - val1;
let prod = mul_twiddle(r1, twiddle_dbl);
(r0, prod)
}
/// Runs ifft on 2 vectors of 16 M31 elements.
///
/// This amounts to 4 butterfly layers, each with 16 butterflies.
/// Each of the vectors represents a bit reversed evaluation.
/// Each value in a vectors is in unreduced form: [0, P] including P.
/// Takes 3 twiddle arrays, one for each layer after the first, holding the double of the
/// corresponding twiddle.
/// The first layer's twiddles (lower bit of the index) are computed from the second layer's
/// twiddles. The second layer takes 8 twiddles.
/// The third layer takes 4 twiddles.
/// The fourth layer takes 2 twiddles.
pub fn vecwise_ibutterflies(
mut val0: PackedBaseField,
mut val1: PackedBaseField,
twiddle1_dbl: [u32; 8],
twiddle2_dbl: [u32; 4],
twiddle3_dbl: [u32; 2],
) -> (PackedBaseField, PackedBaseField) {
// TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly.
// Each `ibutterfly` take 2 512-bit registers, and does 16 butterflies element by element.
// We need to permute the 512-bit registers to get the right order for the butterflies.
// Denote the index of the 16 M31 elements in register i as i:abcd.
// At each layer we apply the following permutation to the index:
// i:abcd => d:iabc
// This is how it looks like at each iteration.
// i:abcd
// d:iabc
// ifft on d
// c:diab
// ifft on c
// b:cdia
// ifft on b
// a:bcid
// ifft on a
// i:abcd
let (t0, t1) = compute_first_twiddles(twiddle1_dbl.into());
// Apply the permutation, resulting in indexing d:iabc.
(val0, val1) = val0.deinterleave(val1);
(val0, val1) = simd_ibutterfly(val0, val1, t0);
// Apply the permutation, resulting in indexing c:diab.
(val0, val1) = val0.deinterleave(val1);
(val0, val1) = simd_ibutterfly(val0, val1, t1);
let t = simd_swizzle!(
u32x4::from(twiddle2_dbl),
[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
);
// Apply the permutation, resulting in indexing b:cdia.
(val0, val1) = val0.deinterleave(val1);
(val0, val1) = simd_ibutterfly(val0, val1, t);
let t = simd_swizzle!(
u32x2::from(twiddle3_dbl),
[0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
);
// Apply the permutation, resulting in indexing a:bcid.
(val0, val1) = val0.deinterleave(val1);
(val0, val1) = simd_ibutterfly(val0, val1, t);
// Apply the permutation, resulting in indexing i:abcd.
val0.deinterleave(val1)
}
/// Returns the line twiddles (x points) for an ifft on a coset.
pub fn get_itwiddle_dbls(mut coset: Coset) -> Vec<Vec<u32>> {
let mut res = vec![];
for _ in 0..coset.log_size() {
res.push(
coset
.iter()
.take(coset.size() / 2)
.map(|p| p.x.inverse().0 * 2)
.collect_vec(),
);
bit_reverse(res.last_mut().unwrap());
coset = coset.double();
}
res
}
/// Applies 3 ibutterfly layers on 8 vectors of 16 M31 elements.
///
/// Vectorized over the 16 elements of the vectors.
/// Used for radix-8 ifft.
/// Each butterfly layer, has 3 SIMD butterflies.
/// Total of 12 SIMD butterflies.
///
/// # Arguments
///
/// - `values`: Pointer to the entire value array.
/// - `offset`: The offset of the first value in the array.
/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values
/// that need to be transformed. For layer i this is i - 4.
/// - `twiddles_dbl0/1/2`: The double of the twiddles for the 3 layers of ibutterflies. Each layer
/// has 4/2/1 twiddles.
///
/// # Safety
///
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
pub unsafe fn ifft3(
values: *mut u32,
offset: usize,
log_step: usize,
twiddles_dbl0: [u32; 4],
twiddles_dbl1: [u32; 2],
twiddles_dbl2: [u32; 1],
) {
// Load the 8 SIMD vectors from the array.
let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const());
let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const());
let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const());
let mut val4 = PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const());
let mut val5 = PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const());
let mut val6 = PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const());
let mut val7 = PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const());
// Apply the first layer of ibutterflies.
(val0, val1) = simd_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0]));
(val2, val3) = simd_ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1]));
(val4, val5) = simd_ibutterfly(val4, val5, u32x16::splat(twiddles_dbl0[2]));
(val6, val7) = simd_ibutterfly(val6, val7, u32x16::splat(twiddles_dbl0[3]));
// Apply the second layer of ibutterflies.
(val0, val2) = simd_ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0]));
(val1, val3) = simd_ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0]));
(val4, val6) = simd_ibutterfly(val4, val6, u32x16::splat(twiddles_dbl1[1]));
(val5, val7) = simd_ibutterfly(val5, val7, u32x16::splat(twiddles_dbl1[1]));
// Apply the third layer of ibutterflies.
(val0, val4) = simd_ibutterfly(val0, val4, u32x16::splat(twiddles_dbl2[0]));
(val1, val5) = simd_ibutterfly(val1, val5, u32x16::splat(twiddles_dbl2[0]));
(val2, val6) = simd_ibutterfly(val2, val6, u32x16::splat(twiddles_dbl2[0]));
(val3, val7) = simd_ibutterfly(val3, val7, u32x16::splat(twiddles_dbl2[0]));
// Store the 8 SIMD vectors back to the array.
val0.store(values.add(offset + (0 << log_step)));
val1.store(values.add(offset + (1 << log_step)));
val2.store(values.add(offset + (2 << log_step)));
val3.store(values.add(offset + (3 << log_step)));
val4.store(values.add(offset + (4 << log_step)));
val5.store(values.add(offset + (5 << log_step)));
val6.store(values.add(offset + (6 << log_step)));
val7.store(values.add(offset + (7 << log_step)));
}
/// Applies 2 ibutterfly layers on 4 vectors of 16 M31 elements.
///
/// Vectorized over the 16 elements of the vectors.
/// Used for radix-4 ifft.
/// Each ibutterfly layer, has 2 SIMD butterflies.
/// Total of 4 SIMD butterflies.
///
/// # Arguments
///
/// - `values`: Pointer to the entire value array.
/// - `offset`: The offset of the first value in the array.
/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values
/// that need to be transformed. For layer `i` this is `i - 4`.
/// - `twiddles_dbl0/1`: The double of the twiddles for the 2 layers of ibutterflies. Each layer has
/// 2/1 twiddles.
///
/// # Safety
///
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
pub unsafe fn ifft2(
values: *mut u32,
offset: usize,
log_step: usize,
twiddles_dbl0: [u32; 2],
twiddles_dbl1: [u32; 1],
) {
// Load the 4 SIMD vectors from the array.
let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const());
let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const());
let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const());
// Apply the first layer of butterflies.
(val0, val1) = simd_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0]));
(val2, val3) = simd_ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1]));
// Apply the second layer of butterflies.
(val0, val2) = simd_ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0]));
(val1, val3) = simd_ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0]));
// Store the 4 SIMD vectors back to the array.
val0.store(values.add(offset + (0 << log_step)));
val1.store(values.add(offset + (1 << log_step)));
val2.store(values.add(offset + (2 << log_step)));
val3.store(values.add(offset + (3 << log_step)));
}
/// Applies 1 ibutterfly layers on 2 vectors of 16 M31 elements.
///
/// Vectorized over the 16 elements of the vectors.
///
/// # Arguments
///
/// - `values`: Pointer to the entire value array.
/// - `offset`: The offset of the first value in the array.
/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values
/// that need to be transformed. For layer `i` this is `i - 4`.
/// - `twiddles_dbl0`: The double of the twiddles for the ibutterfly layer.
///
/// # Safety
///
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
pub unsafe fn ifft1(values: *mut u32, offset: usize, log_step: usize, twiddles_dbl0: [u32; 1]) {
// Load the 2 SIMD vectors from the array.
let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const());
(val0, val1) = simd_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0]));
// Store the 2 SIMD vectors back to the array.
val0.store(values.add(offset + (0 << log_step)));
val1.store(values.add(offset + (1 << log_step)));
}
#[cfg(test)]
mod tests {
use std::mem::transmute;
use std::simd::u32x16;
use itertools::Itertools;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use super::{
get_itwiddle_dbls, ifft, ifft3, ifft_lower_with_vecwise, simd_ibutterfly,
vecwise_ibutterflies,
};
use crate::core::backend::cpu::CpuCircleEvaluation;
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE};
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
use crate::core::backend::Column;
use crate::core::fft::ibutterfly as ground_truth_ibutterfly;
use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::{CanonicCoset, CircleDomain};
#[test]
fn test_ibutterfly() {
let mut rng = SmallRng::seed_from_u64(0);
let mut v0: [BaseField; N_LANES] = rng.gen();
let mut v1: [BaseField; N_LANES] = rng.gen();
let twiddle: [BaseField; N_LANES] = rng.gen();
let twiddle_dbl = twiddle.map(|v| v.0 * 2);
let (r0, r1) = simd_ibutterfly(v0.into(), v1.into(), twiddle_dbl.into());
let r0 = r0.to_array();
let r1 = r1.to_array();
for i in 0..N_LANES {
ground_truth_ibutterfly(&mut v0[i], &mut v1[i], twiddle[i]);
assert_eq!((v0[i], v1[i]), (r0[i], r1[i]), "mismatch at i={i}");
}
}
#[test]
fn test_ifft3() {
let mut rng = SmallRng::seed_from_u64(0);
let values = rng.gen::<[BaseField; 8]>().map(PackedBaseField::broadcast);
let twiddles0: [BaseField; 4] = rng.gen();
let twiddles1: [BaseField; 2] = rng.gen();
let twiddles2: [BaseField; 1] = rng.gen();
let twiddles0_dbl = twiddles0.map(|v| v.0 * 2);
let twiddles1_dbl = twiddles1.map(|v| v.0 * 2);
let twiddles2_dbl = twiddles2.map(|v| v.0 * 2);
let mut res = values;
unsafe {
ifft3(
transmute(res.as_mut_ptr()),
0,
LOG_N_LANES as usize,
twiddles0_dbl,
twiddles1_dbl,
twiddles2_dbl,
)
};
let mut expected = values.map(|v| v.to_array()[0]);
for i in 0..8 {
let j = i ^ 1;
if i > j {
continue;
}
let (mut v0, mut v1) = (expected[i], expected[j]);
ground_truth_ibutterfly(&mut v0, &mut v1, twiddles0[i / 2]);
(expected[i], expected[j]) = (v0, v1);
}
for i in 0..8 {
let j = i ^ 2;
if i > j {
continue;
}
let (mut v0, mut v1) = (expected[i], expected[j]);
ground_truth_ibutterfly(&mut v0, &mut v1, twiddles1[i / 4]);
(expected[i], expected[j]) = (v0, v1);
}
for i in 0..8 {
let j = i ^ 4;
if i > j {
continue;
}
let (mut v0, mut v1) = (expected[i], expected[j]);
ground_truth_ibutterfly(&mut v0, &mut v1, twiddles2[0]);
(expected[i], expected[j]) = (v0, v1);
}
for i in 0..8 {
assert_eq!(
res[i].to_array(),
[expected[i]; N_LANES],
"mismatch at i={i}"
);
}
}
#[test]
fn test_vecwise_ibutterflies() {
let domain = CanonicCoset::new(5).circle_domain();
let twiddle_dbls = get_itwiddle_dbls(domain.half_coset);
assert_eq!(twiddle_dbls.len(), 4);
let mut rng = SmallRng::seed_from_u64(0);
let values: [[BaseField; 16]; 2] = rng.gen();
let res = {
let (val0, val1) = vecwise_ibutterflies(
values[0].into(),
values[1].into(),
twiddle_dbls[0].clone().try_into().unwrap(),
twiddle_dbls[1].clone().try_into().unwrap(),
twiddle_dbls[2].clone().try_into().unwrap(),
);
let (val0, val1) = simd_ibutterfly(val0, val1, u32x16::splat(twiddle_dbls[3][0]));
[val0.to_array(), val1.to_array()].concat()
};
assert_eq!(res, ground_truth_ifft(domain, values.flatten()));
}
#[test]
fn test_ifft_lower_with_vecwise() {
for log_size in 5..12 {
let domain = CanonicCoset::new(log_size).circle_domain();
let mut rng = SmallRng::seed_from_u64(0);
let values = (0..domain.size()).map(|_| rng.gen()).collect_vec();
let twiddle_dbls = get_itwiddle_dbls(domain.half_coset);
let mut res = values.iter().copied().collect::<BaseColumn>();
unsafe {
ifft_lower_with_vecwise(
transmute(res.data.as_mut_ptr()),
&twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(),
log_size as usize,
log_size as usize,
);
}
assert_eq!(res.to_cpu(), ground_truth_ifft(domain, &values));
}
}
#[test]
fn test_ifft_full() {
for log_size in CACHED_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 3 {
let domain = CanonicCoset::new(log_size).circle_domain();
let mut rng = SmallRng::seed_from_u64(0);
let values = (0..domain.size()).map(|_| rng.gen()).collect_vec();
let twiddle_dbls = get_itwiddle_dbls(domain.half_coset);
let mut res = values.iter().copied().collect::<BaseColumn>();
unsafe {
ifft(
transmute(res.data.as_mut_ptr()),
&twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(),
log_size as usize,
);
transpose_vecs(transmute(res.data.as_mut_ptr()), log_size as usize - 4);
}
assert_eq!(res.to_cpu(), ground_truth_ifft(domain, &values));
}
}
fn ground_truth_ifft(domain: CircleDomain, values: &[BaseField]) -> Vec<BaseField> {
let eval = CpuCircleEvaluation::new(domain, values.to_vec());
let mut res = eval.interpolate().coeffs;
let denorm = BaseField::from(domain.size());
res.iter_mut().for_each(|v| *v *= denorm);
res
}
}

View File

@ -0,0 +1,120 @@
use std::simd::{simd_swizzle, u32x16, u32x8};
use super::m31::PackedBaseField;
use crate::core::fields::m31::P;
pub mod ifft;
pub mod rfft;
pub const CACHED_FFT_LOG_SIZE: u32 = 16;
pub const MIN_FFT_LOG_SIZE: u32 = 5;
// TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce
// it somewhere.
/// Transposes the SIMD vectors in the given array.
///
/// Swaps the bit index abc <-> cba, where |a|=|c| and |b| = 0 or 1, according to the parity of
/// `log_n_vecs`.
/// When log_n_vecs is odd, transforms the index abc <-> cba, w
///
/// # Arguments
///
/// - `values`: A mutable pointer to the values that are to be transposed.
/// - `log_n_vecs`: The log of the number of SIMD vectors in the `values` array.
///
/// # Safety
///
/// Behavior is undefined if `values` does not have the same alignment as [`u32x16`].
pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) {
let half = log_n_vecs / 2;
for b in 0..1 << (log_n_vecs & 1) {
for a in 0..1 << half {
for c in 0..1 << half {
let i = (a << (log_n_vecs - half)) | (b << half) | c;
let j = (c << (log_n_vecs - half)) | (b << half) | a;
if i >= j {
continue;
}
let val0 = load(values.add(i << 4).cast_const());
let val1 = load(values.add(j << 4).cast_const());
store(values.add(i << 4), val1);
store(values.add(j << 4), val0);
}
}
}
}
/// Computes the twiddles for the first fft layer from the second, and loads both to SIMD registers.
///
/// Returns the twiddles for the first layer and the twiddles for the second layer.
pub fn compute_first_twiddles(twiddle1_dbl: u32x8) -> (u32x16, u32x16) {
// Start by loading the twiddles for the second layer (layer 1):
let t1 = simd_swizzle!(
twiddle1_dbl,
twiddle1_dbl,
[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
);
// The twiddles for layer 0 can be computed from the twiddles for layer 1.
// Since the twiddles are bit reversed, we consider the circle domain in bit reversed order.
// Each consecutive 4 points in the bit reversed order of a coset form a circle coset of size 4.
// A circle coset of size 4 in bit reversed order looks like this:
// [(x, y), (-x, -y), (y, -x), (-y, x)]
// Note: This is related to the choice of M31_CIRCLE_GEN, and the fact the a quarter rotation
// is (0,-1) and not (0,1). (0,1) would yield another relation.
// The twiddles for layer 0 are the y coordinates:
// [y, -y, -x, x]
// The twiddles for layer 1 in bit reversed order are the x coordinates:
// [x, y]
// Works also for inverse of the twiddles.
// The twiddles for layer 0 are computed like this:
// t0[4i:4i+3] = [t1[2i+1], -t1[2i+1], -t1[2i], t1[2i]]
// Xoring a double twiddle with P*2 transforms it to the double of it negation.
// Note that this keeps the values as a double of a value in the range [0, P].
const P2: u32 = P * 2;
const NEGATION_MASK: u32x16 =
u32x16::from_array([0, P2, P2, 0, 0, P2, P2, 0, 0, P2, P2, 0, 0, P2, P2, 0]);
let t0 = simd_swizzle!(
t1,
[
0b0001, 0b0001, 0b0000, 0b0000, 0b0011, 0b0011, 0b0010, 0b0010, 0b0101, 0b0101, 0b0100,
0b0100, 0b0111, 0b0111, 0b0110, 0b0110,
]
) ^ NEGATION_MASK;
(t0, t1)
}
#[inline]
unsafe fn load(mem_addr: *const u32) -> u32x16 {
std::ptr::read(mem_addr as *const u32x16)
}
#[inline]
unsafe fn store(mem_addr: *mut u32, a: u32x16) {
std::ptr::write(mem_addr as *mut u32x16, a);
}
/// Computes `v * twiddle`
fn mul_twiddle(v: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField {
// TODO: Come up with a better approach than `cfg`ing on target_feature.
// TODO: Ensure all these branches get tested in the CI.
cfg_if::cfg_if! {
if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] {
// TODO: For architectures that when multiplying require doubling then the twiddles
// should be precomputed as double. For other architectures, the twiddle should be
// precomputed without doubling.
crate::core::backend::simd::m31::_mul_doubled_neon(v, twiddle_dbl)
} else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] {
crate::core::backend::simd::m31::_mul_doubled_wasm(v, twiddle_dbl)
} else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] {
crate::core::backend::simd::m31::_mul_doubled_avx512(v, twiddle_dbl)
} else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] {
crate::core::backend::simd::m31::_mul_doubled_avx2(v, twiddle_dbl)
} else {
crate::core::backend::simd::m31::_mul_doubled_simd(v, twiddle_dbl)
}
}
}

View File

@ -0,0 +1,742 @@
//! Regular (forward) fft.
use std::array;
use std::simd::{simd_swizzle, u32x16, u32x2, u32x4, u32x8};
use itertools::Itertools;
use super::{
compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
};
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::circle::Coset;
use crate::core::utils::bit_reverse;
/// Performs a Circle Fast Fourier Transform (CFFT) on the given values.
///
/// # Arguments
///
/// * `src`: A pointer to the values to transform.
/// * `dst`: A pointer to the destination array.
/// * `twiddle_dbl`: A reference to the doubles of the twiddle factors.
/// * `log_n_elements`: The log of the number of elements in the `values` array.
///
/// # Panics
///
/// This function will panic if `log_n_elements` is less than `MIN_FFT_LOG_SIZE`.
///
/// # Safety
///
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
pub unsafe fn fft(src: *const u32, dst: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) {
assert!(log_n_elements >= MIN_FFT_LOG_SIZE as usize);
let log_n_vecs = log_n_elements - LOG_N_LANES as usize;
if log_n_elements <= CACHED_FFT_LOG_SIZE as usize {
fft_lower_with_vecwise(src, dst, twiddle_dbl, log_n_elements, log_n_elements);
return;
}
let fft_layers_pre_transpose = log_n_vecs.div_ceil(2);
let fft_layers_post_transpose = log_n_vecs / 2;
fft_lower_without_vecwise(
src,
dst,
&twiddle_dbl[(3 + fft_layers_pre_transpose)..],
log_n_elements,
fft_layers_post_transpose,
);
transpose_vecs(dst, log_n_vecs);
fft_lower_with_vecwise(
dst,
dst,
&twiddle_dbl[..3 + fft_layers_pre_transpose],
log_n_elements,
fft_layers_pre_transpose + LOG_N_LANES as usize,
);
}
/// Computes partial fft on `2^log_size` M31 elements.
///
/// # Arguments
///
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the the fft. Layer `i`
/// holds `2^(log_size - 1 - i)` twiddles.
/// - `log_size`: The log of the number of number of M31 elements in the array.
/// - `fft_layers`: The number of fft layers to apply, out of log_size.
///
/// # Panics
///
/// Panics if `log_size` is not at least 5.
///
/// # Safety
///
/// `src` and `dst` must have same alignment as [`PackedBaseField`].
/// `fft_layers` must be at least 5.
pub unsafe fn fft_lower_with_vecwise(
src: *const u32,
dst: *mut u32,
twiddle_dbl: &[&[u32]],
log_size: usize,
fft_layers: usize,
) {
const VECWISE_FFT_BITS: usize = LOG_N_LANES as usize + 1;
assert!(log_size >= VECWISE_FFT_BITS);
assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2));
for index_h in 0..1 << (log_size - fft_layers) {
let mut src = src;
for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3).rev() {
match fft_layers - layer {
1 => {
fft1_loop(src, dst, &twiddle_dbl[(layer - 1)..], layer, index_h);
}
2 => {
fft2_loop(src, dst, &twiddle_dbl[(layer - 1)..], layer, index_h);
}
_ => {
fft3_loop(
src,
dst,
&twiddle_dbl[(layer - 1)..],
fft_layers - layer - 3,
layer,
index_h,
);
}
}
src = dst;
}
fft_vecwise_loop(
src,
dst,
twiddle_dbl,
fft_layers - VECWISE_FFT_BITS,
index_h,
);
}
}
/// Computes partial fft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of
/// the index).
///
/// # Arguments
///
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the the fft.
/// - `log_size`: The log of the number of number of M31 elements in the array.
/// - `fft_layers`: The number of fft layers to apply, out of log_size - VEC_LOG_SIZE.
///
/// # Panics
///
/// Panics if `log_size` is not at least 4.
///
/// # Safety
///
/// `src` and `dst` must have same alignment as [`PackedBaseField`].
/// `fft_layers` must be at least 4.
pub unsafe fn fft_lower_without_vecwise(
src: *const u32,
dst: *mut u32,
twiddle_dbl: &[&[u32]],
log_size: usize,
fft_layers: usize,
) {
assert!(log_size >= LOG_N_LANES as usize);
for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) {
let mut src = src;
for layer in (0..fft_layers).step_by(3).rev() {
let fixed_layer = layer + LOG_N_LANES as usize;
match fft_layers - layer {
1 => {
fft1_loop(src, dst, &twiddle_dbl[layer..], fixed_layer, index_h);
}
2 => {
fft2_loop(src, dst, &twiddle_dbl[layer..], fixed_layer, index_h);
}
_ => {
fft3_loop(
src,
dst,
&twiddle_dbl[layer..],
fft_layers - layer - 3,
fixed_layer,
index_h,
);
}
}
src = dst;
}
}
}
/// Runs the last 5 fft layers across the entire array.
///
/// # Arguments
///
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 5 fft layers.
/// - `high_bits`: The number of bits this loops needs to run on.
/// - `index_h`: The higher part of the index, iterated by the caller.
///
/// # Safety
///
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
unsafe fn fft_vecwise_loop(
src: *const u32,
dst: *mut u32,
twiddle_dbl: &[&[u32]],
loop_bits: usize,
index_h: usize,
) {
for index_l in 0..1 << loop_bits {
let index = (index_h << loop_bits) + index_l;
let mut val0 = PackedBaseField::load(src.add(index * 32));
let mut val1 = PackedBaseField::load(src.add(index * 32 + 16));
(val0, val1) = simd_butterfly(
val0,
val1,
u32x16::splat(*twiddle_dbl[3].get_unchecked(index)),
);
(val0, val1) = vecwise_butterflies(
val0,
val1,
array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)),
array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)),
array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)),
);
val0.store(dst.add(index * 32));
val1.store(dst.add(index * 32 + 16));
}
}
/// Runs 3 fft layers across the entire array.
///
/// # Arguments
///
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 3 fft layers.
/// - `loop_bits`: The number of bits this loops needs to run on.
/// - `layer`: The layer number of the first fft layer to apply. The layers `layer`, `layer + 1`,
/// `layer + 2` are applied.
/// - `index_h`: The higher part of the index, iterated by the caller.
///
/// # Safety
///
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
unsafe fn fft3_loop(
src: *const u32,
dst: *mut u32,
twiddle_dbl: &[&[u32]],
loop_bits: usize,
layer: usize,
index_h: usize,
) {
for index_l in 0..1 << loop_bits {
let index = (index_h << loop_bits) + index_l;
let offset = index << (layer + 3);
for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) {
fft3(
src,
dst,
offset + l,
layer,
array::from_fn(|i| {
*twiddle_dbl[0].get_unchecked((index * 4 + i) & (twiddle_dbl[0].len() - 1))
}),
array::from_fn(|i| {
*twiddle_dbl[1].get_unchecked((index * 2 + i) & (twiddle_dbl[1].len() - 1))
}),
array::from_fn(|i| {
*twiddle_dbl[2].get_unchecked((index + i) & (twiddle_dbl[2].len() - 1))
}),
);
}
}
}
/// Runs 2 fft layers across the entire array.
///
/// # Arguments
///
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 2 fft layers.
/// - `loop_bits`: The number of bits this loops needs to run on.
/// - `layer`: The layer number of the first fft layer to apply. The layers `layer`, `layer + 1` are
/// applied.
/// - `index`: The index, iterated by the caller.
///
/// # Safety
///
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
unsafe fn fft2_loop(
src: *const u32,
dst: *mut u32,
twiddle_dbl: &[&[u32]],
layer: usize,
index: usize,
) {
let offset = index << (layer + 2);
for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) {
fft2(
src,
dst,
offset + l,
layer,
array::from_fn(|i| {
*twiddle_dbl[0].get_unchecked((index * 2 + i) & (twiddle_dbl[0].len() - 1))
}),
array::from_fn(|i| {
*twiddle_dbl[1].get_unchecked((index + i) & (twiddle_dbl[1].len() - 1))
}),
);
}
}
/// Runs 1 fft layer across the entire array.
///
/// # Arguments
///
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
/// - `twiddle_dbl`: The doubles of the twiddle factors for the fft layer.
/// - `layer`: The layer number of the fft layer to apply.
/// - `index_h`: The higher part of the index, iterated by the caller.
///
/// # Safety
///
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
unsafe fn fft1_loop(
src: *const u32,
dst: *mut u32,
twiddle_dbl: &[&[u32]],
layer: usize,
index: usize,
) {
let offset = index << (layer + 1);
for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) {
fft1(
src,
dst,
offset + l,
layer,
array::from_fn(|i| {
*twiddle_dbl[0].get_unchecked((index + i) & (twiddle_dbl[0].len() - 1))
}),
);
}
}
/// Computes the butterfly operation for packed M31 elements.
///
/// Returns `val0 + t val1, val0 - t val1`. `val0, val1` are packed M31 elements. 16 M31 words at
/// each. Each value is assumed to be in unreduced form, [0, P] including P. Returned values are in
/// unreduced form, [0, P] including P. twiddle_dbl holds 16 values, each is a *double* of a twiddle
/// factor, in unreduced form, [0, 2*P].
pub fn simd_butterfly(
val0: PackedBaseField,
val1: PackedBaseField,
twiddle_dbl: u32x16,
) -> (PackedBaseField, PackedBaseField) {
let prod = mul_twiddle(val1, twiddle_dbl);
(val0 + prod, val0 - prod)
}
/// Runs fft on 2 vectors of 16 M31 elements.
///
/// This amounts to 4 butterfly layers, each with 16 butterflies.
/// Each of the vectors represents natural ordered polynomial coefficeint.
/// Each value in a vectors is in unreduced form: [0, P] including P.
/// Takes 4 twiddle arrays, one for each layer, holding the double of the corresponding twiddle.
/// The first layer (higher bit of the index) takes 2 twiddles.
/// The second layer takes 4 twiddles.
/// etc.
pub fn vecwise_butterflies(
mut val0: PackedBaseField,
mut val1: PackedBaseField,
twiddle1_dbl: [u32; 8],
twiddle2_dbl: [u32; 4],
twiddle3_dbl: [u32; 2],
) -> (PackedBaseField, PackedBaseField) {
// TODO(spapini): Compute twiddle0 from twiddle1.
// TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly.
// The implementation is the exact reverse of vecwise_ibutterflies().
// See the comments in its body for more info.
let t = simd_swizzle!(
u32x2::from(twiddle3_dbl),
[0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
);
(val0, val1) = val0.interleave(val1);
(val0, val1) = simd_butterfly(val0, val1, t);
let t = simd_swizzle!(
u32x4::from(twiddle2_dbl),
[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
);
(val0, val1) = val0.interleave(val1);
(val0, val1) = simd_butterfly(val0, val1, t);
let (t0, t1) = compute_first_twiddles(u32x8::from(twiddle1_dbl));
(val0, val1) = val0.interleave(val1);
(val0, val1) = simd_butterfly(val0, val1, t1);
(val0, val1) = val0.interleave(val1);
(val0, val1) = simd_butterfly(val0, val1, t0);
val0.interleave(val1)
}
/// Returns the line twiddles (x points) for an fft on a coset.
pub fn get_twiddle_dbls(mut coset: Coset) -> Vec<Vec<u32>> {
let mut res = vec![];
for _ in 0..coset.log_size() {
res.push(
coset
.iter()
.take(coset.size() / 2)
.map(|p| p.x.0 * 2)
.collect_vec(),
);
bit_reverse(res.last_mut().unwrap());
coset = coset.double();
}
res
}
/// Applies 3 butterfly layers on 8 vectors of 16 M31 elements.
///
/// Vectorized over the 16 elements of the vectors.
/// Used for radix-8 ifft.
/// Each butterfly layer, has 3 SIMD butterflies.
/// Total of 12 SIMD butterflies.
///
/// # Arguments
///
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
/// - `offset`: The offset of the first value in the array.
/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values
/// that need to be transformed. For layer i this is i - 4.
/// - `twiddles_dbl0/1/2`: The double of the twiddles for the 3 layers of butterflies. Each layer
/// has 4/2/1 twiddles.
///
/// # Safety
///
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
pub unsafe fn fft3(
src: *const u32,
dst: *mut u32,
offset: usize,
log_step: usize,
twiddles_dbl0: [u32; 4],
twiddles_dbl1: [u32; 2],
twiddles_dbl2: [u32; 1],
) {
// Load the 8 SIMD vectors from the array.
let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)));
let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)));
let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step)));
let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step)));
let mut val4 = PackedBaseField::load(src.add(offset + (4 << log_step)));
let mut val5 = PackedBaseField::load(src.add(offset + (5 << log_step)));
let mut val6 = PackedBaseField::load(src.add(offset + (6 << log_step)));
let mut val7 = PackedBaseField::load(src.add(offset + (7 << log_step)));
// Apply the third layer of butterflies.
(val0, val4) = simd_butterfly(val0, val4, u32x16::splat(twiddles_dbl2[0]));
(val1, val5) = simd_butterfly(val1, val5, u32x16::splat(twiddles_dbl2[0]));
(val2, val6) = simd_butterfly(val2, val6, u32x16::splat(twiddles_dbl2[0]));
(val3, val7) = simd_butterfly(val3, val7, u32x16::splat(twiddles_dbl2[0]));
// Apply the second layer of butterflies.
(val0, val2) = simd_butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0]));
(val1, val3) = simd_butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0]));
(val4, val6) = simd_butterfly(val4, val6, u32x16::splat(twiddles_dbl1[1]));
(val5, val7) = simd_butterfly(val5, val7, u32x16::splat(twiddles_dbl1[1]));
// Apply the first layer of butterflies.
(val0, val1) = simd_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0]));
(val2, val3) = simd_butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1]));
(val4, val5) = simd_butterfly(val4, val5, u32x16::splat(twiddles_dbl0[2]));
(val6, val7) = simd_butterfly(val6, val7, u32x16::splat(twiddles_dbl0[3]));
// Store the 8 SIMD vectors back to the array.
val0.store(dst.add(offset + (0 << log_step)));
val1.store(dst.add(offset + (1 << log_step)));
val2.store(dst.add(offset + (2 << log_step)));
val3.store(dst.add(offset + (3 << log_step)));
val4.store(dst.add(offset + (4 << log_step)));
val5.store(dst.add(offset + (5 << log_step)));
val6.store(dst.add(offset + (6 << log_step)));
val7.store(dst.add(offset + (7 << log_step)));
}
/// Applies 2 butterfly layers on 4 vectors of 16 M31 elements.
///
/// Vectorized over the 16 elements of the vectors.
/// Used for radix-4 fft.
/// Each butterfly layer, has 2 SIMD butterflies.
/// Total of 4 SIMD butterflies.
///
/// # Arguments
///
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
/// - `offset`: The offset of the first value in the array.
/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values
/// that need to be transformed. For layer i this is i - 4.
/// - `twiddles_dbl0/1`: The double of the twiddles for the 2 layers of butterflies. Each layer has
/// 2/1 twiddles.
///
/// # Safety
///
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
pub unsafe fn fft2(
src: *const u32,
dst: *mut u32,
offset: usize,
log_step: usize,
twiddles_dbl0: [u32; 2],
twiddles_dbl1: [u32; 1],
) {
// Load the 4 SIMD vectors from the array.
let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)));
let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)));
let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step)));
let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step)));
// Apply the second layer of butterflies.
(val0, val2) = simd_butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0]));
(val1, val3) = simd_butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0]));
// Apply the first layer of butterflies.
(val0, val1) = simd_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0]));
(val2, val3) = simd_butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1]));
// Store the 4 SIMD vectors back to the array.
val0.store(dst.add(offset + (0 << log_step)));
val1.store(dst.add(offset + (1 << log_step)));
val2.store(dst.add(offset + (2 << log_step)));
val3.store(dst.add(offset + (3 << log_step)));
}
/// Applies 1 butterfly layers on 2 vectors of 16 M31 elements.
///
/// Vectorized over the 16 elements of the vectors.
///
/// # Arguments
///
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
/// - `offset`: The offset of the first value in the array.
/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values
/// that need to be transformed. For layer i this is i - 4.
/// - `twiddles_dbl0`: The double of the twiddles for the butterfly layer.
///
/// # Safety
///
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
pub unsafe fn fft1(
src: *const u32,
dst: *mut u32,
offset: usize,
log_step: usize,
twiddles_dbl0: [u32; 1],
) {
// Load the 2 SIMD vectors from the array.
let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)));
let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)));
(val0, val1) = simd_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0]));
// Store the 2 SIMD vectors back to the array.
val0.store(dst.add(offset + (0 << log_step)));
val1.store(dst.add(offset + (1 << log_step)));
}
#[cfg(test)]
mod tests {
use std::mem::transmute;
use std::simd::u32x16;
use itertools::Itertools;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use super::{
fft, fft3, fft_lower_with_vecwise, get_twiddle_dbls, simd_butterfly, vecwise_butterflies,
};
use crate::core::backend::cpu::CpuCirclePoly;
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE};
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
use crate::core::backend::Column;
use crate::core::fft::butterfly as ground_truth_butterfly;
use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::{CanonicCoset, CircleDomain};
#[test]
fn test_butterfly() {
let mut rng = SmallRng::seed_from_u64(0);
let mut v0: [BaseField; N_LANES] = rng.gen();
let mut v1: [BaseField; N_LANES] = rng.gen();
let twiddle: [BaseField; N_LANES] = rng.gen();
let twiddle_dbl = twiddle.map(|v| v.0 * 2);
let (r0, r1) = simd_butterfly(v0.into(), v1.into(), twiddle_dbl.into());
let r0 = r0.to_array();
let r1 = r1.to_array();
for i in 0..N_LANES {
ground_truth_butterfly(&mut v0[i], &mut v1[i], twiddle[i]);
assert_eq!((v0[i], v1[i]), (r0[i], r1[i]), "mismatch at i={i}");
}
}
#[test]
fn test_fft3() {
let mut rng = SmallRng::seed_from_u64(0);
let values = rng.gen::<[BaseField; 8]>().map(PackedBaseField::broadcast);
let twiddles0: [BaseField; 4] = rng.gen();
let twiddles1: [BaseField; 2] = rng.gen();
let twiddles2: [BaseField; 1] = rng.gen();
let twiddles0_dbl = twiddles0.map(|v| v.0 * 2);
let twiddles1_dbl = twiddles1.map(|v| v.0 * 2);
let twiddles2_dbl = twiddles2.map(|v| v.0 * 2);
let mut res = values;
unsafe {
fft3(
transmute(res.as_ptr()),
transmute(res.as_mut_ptr()),
0,
LOG_N_LANES as usize,
twiddles0_dbl,
twiddles1_dbl,
twiddles2_dbl,
)
};
let mut expected = values.map(|v| v.to_array()[0]);
for i in 0..8 {
let j = i ^ 4;
if i > j {
continue;
}
let (mut v0, mut v1) = (expected[i], expected[j]);
ground_truth_butterfly(&mut v0, &mut v1, twiddles2[0]);
(expected[i], expected[j]) = (v0, v1);
}
for i in 0..8 {
let j = i ^ 2;
if i > j {
continue;
}
let (mut v0, mut v1) = (expected[i], expected[j]);
ground_truth_butterfly(&mut v0, &mut v1, twiddles1[i / 4]);
(expected[i], expected[j]) = (v0, v1);
}
for i in 0..8 {
let j = i ^ 1;
if i > j {
continue;
}
let (mut v0, mut v1) = (expected[i], expected[j]);
ground_truth_butterfly(&mut v0, &mut v1, twiddles0[i / 2]);
(expected[i], expected[j]) = (v0, v1);
}
for i in 0..8 {
assert_eq!(
res[i].to_array(),
[expected[i]; N_LANES],
"mismatch at i={i}"
);
}
}
#[test]
fn test_vecwise_butterflies() {
let domain = CanonicCoset::new(5).circle_domain();
let twiddle_dbls = get_twiddle_dbls(domain.half_coset);
assert_eq!(twiddle_dbls.len(), 4);
let mut rng = SmallRng::seed_from_u64(0);
let values: [[BaseField; 16]; 2] = rng.gen();
let res = {
let (val0, val1) = simd_butterfly(
values[0].into(),
values[1].into(),
u32x16::splat(twiddle_dbls[3][0]),
);
let (val0, val1) = vecwise_butterflies(
val0,
val1,
twiddle_dbls[0].clone().try_into().unwrap(),
twiddle_dbls[1].clone().try_into().unwrap(),
twiddle_dbls[2].clone().try_into().unwrap(),
);
[val0.to_array(), val1.to_array()].concat()
};
assert_eq!(res, ground_truth_fft(domain, values.flatten()));
}
#[test]
fn test_fft_lower() {
for log_size in 5..12 {
let domain = CanonicCoset::new(log_size).circle_domain();
let mut rng = SmallRng::seed_from_u64(0);
let values = (0..domain.size()).map(|_| rng.gen()).collect_vec();
let twiddle_dbls = get_twiddle_dbls(domain.half_coset);
let mut res = values.iter().copied().collect::<BaseColumn>();
unsafe {
fft_lower_with_vecwise(
transmute(res.data.as_ptr()),
transmute(res.data.as_mut_ptr()),
&twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(),
log_size as usize,
log_size as usize,
)
}
assert_eq!(res.to_cpu(), ground_truth_fft(domain, &values));
}
}
#[test]
fn test_fft_full() {
for log_size in CACHED_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 3 {
let domain = CanonicCoset::new(log_size).circle_domain();
let mut rng = SmallRng::seed_from_u64(0);
let values = (0..domain.size()).map(|_| rng.gen()).collect_vec();
let twiddle_dbls = get_twiddle_dbls(domain.half_coset);
let mut res = values.iter().copied().collect::<BaseColumn>();
unsafe {
transpose_vecs(transmute(res.data.as_mut_ptr()), log_size as usize - 4);
fft(
transmute(res.data.as_ptr()),
transmute(res.data.as_mut_ptr()),
&twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(),
log_size as usize,
);
}
assert_eq!(res.to_cpu(), ground_truth_fft(domain, &values));
}
}
fn ground_truth_fft(domain: CircleDomain, values: &[BaseField]) -> Vec<BaseField> {
let poly = CpuCirclePoly::new(values.to_vec());
poly.evaluate(domain).values
}
}

View File

@ -0,0 +1,261 @@
use std::array;
use std::simd::u32x8;
use num_traits::Zero;
use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
use super::SimdBackend;
use crate::core::backend::simd::fft::compute_first_twiddles;
use crate::core::backend::simd::fft::ifft::simd_ibutterfly;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::Column;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fri::{self, FriOps};
use crate::core::poly::circle::SecureEvaluation;
use crate::core::poly::line::LineEvaluation;
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::utils::domain_line_twiddles_from_tree;
use crate::core::poly::BitReversedOrder;
impl FriOps for SimdBackend {
fn fold_line(
eval: &LineEvaluation<Self>,
alpha: SecureField,
twiddles: &TwiddleTree<Self>,
) -> LineEvaluation<Self> {
let log_size = eval.len().ilog2();
if log_size <= LOG_N_LANES {
let eval = fri::fold_line(&eval.to_cpu(), alpha);
return LineEvaluation::new(eval.domain(), eval.values.into_iter().collect());
}
let domain = eval.domain();
let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0];
let mut folded_values = SecureColumnByCoords::<Self>::zeros(1 << (log_size - 1));
for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) {
let value = unsafe {
let twiddle_dbl: [u32; 16] =
array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i));
let val0 = eval.values.packed_at(vec_index * 2).into_packed_m31s();
let val1 = eval.values.packed_at(vec_index * 2 + 1).into_packed_m31s();
let pairs: [_; 4] = array::from_fn(|i| {
let (a, b) = val0[i].deinterleave(val1[i]);
simd_ibutterfly(a, b, std::mem::transmute(twiddle_dbl))
});
let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0));
let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1));
val0 + PackedSecureField::broadcast(alpha) * val1
};
unsafe { folded_values.set_packed(vec_index, value) };
}
LineEvaluation::new(domain.double(), folded_values)
}
fn fold_circle_into_line(
dst: &mut LineEvaluation<Self>,
src: &SecureEvaluation<Self, BitReversedOrder>,
alpha: SecureField,
twiddles: &TwiddleTree<Self>,
) {
let log_size = src.len().ilog2();
assert!(log_size > LOG_N_LANES, "Evaluation too small");
let domain = src.domain;
let alpha_sq = alpha * alpha;
let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0];
for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) {
let value = unsafe {
// The 16 twiddles of the circle domain can be derived from the 8 twiddles of the
// next line domain. See `compute_first_twiddles()`.
let twiddle_dbl = u32x8::from_array(array::from_fn(|i| {
*itwiddles.get_unchecked(vec_index * 8 + i)
}));
let (t0, _) = compute_first_twiddles(twiddle_dbl);
let val0 = src.values.packed_at(vec_index * 2).into_packed_m31s();
let val1 = src.values.packed_at(vec_index * 2 + 1).into_packed_m31s();
let pairs: [_; 4] = array::from_fn(|i| {
let (a, b) = val0[i].deinterleave(val1[i]);
simd_ibutterfly(a, b, t0)
});
let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0));
let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1));
val0 + PackedSecureField::broadcast(alpha) * val1
};
unsafe {
dst.values.set_packed(
vec_index,
dst.values.packed_at(vec_index) * PackedSecureField::broadcast(alpha_sq)
+ value,
)
};
}
}
fn decompose(
eval: &SecureEvaluation<Self, BitReversedOrder>,
) -> (SecureEvaluation<Self, BitReversedOrder>, SecureField) {
let lambda = decomposition_coefficient(eval);
let broadcasted_lambda = PackedSecureField::broadcast(lambda);
let mut g_values = SecureColumnByCoords::<Self>::zeros(eval.len());
let range = eval.len().div_ceil(N_LANES);
let half_range = range / 2;
for i in 0..half_range {
let val = unsafe { eval.packed_at(i) } - broadcasted_lambda;
unsafe { g_values.set_packed(i, val) }
}
for i in half_range..range {
let val = unsafe { eval.packed_at(i) } + broadcasted_lambda;
unsafe { g_values.set_packed(i, val) }
}
let g = SecureEvaluation::new(eval.domain, g_values);
(g, lambda)
}
}
/// See [`decomposition_coefficient`].
///
/// [`decomposition_coefficient`]: crate::core::backend::cpu::CpuBackend::decomposition_coefficient
fn decomposition_coefficient(
eval: &SecureEvaluation<SimdBackend, BitReversedOrder>,
) -> SecureField {
let cols = &eval.values.columns;
let [mut x_sum, mut y_sum, mut z_sum, mut w_sum] = [PackedBaseField::zero(); 4];
let range = cols[0].len() / N_LANES;
let (half_a, half_b) = (range / 2, range);
for i in 0..half_a {
x_sum += cols[0].data[i];
y_sum += cols[1].data[i];
z_sum += cols[2].data[i];
w_sum += cols[3].data[i];
}
for i in half_a..half_b {
x_sum -= cols[0].data[i];
y_sum -= cols[1].data[i];
z_sum -= cols[2].data[i];
w_sum -= cols[3].data[i];
}
let x = x_sum.pointwise_sum();
let y = y_sum.pointwise_sum();
let z = z_sum.pointwise_sum();
let w = w_sum.pointwise_sum();
SecureField::from_m31(x, y, z, w) / BaseField::from_u32_unchecked(1 << eval.domain.log_size())
}
#[cfg(test)]
mod tests {
use itertools::Itertools;
use num_traits::One;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, CpuBackend};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fri::FriOps;
use crate::core::poly::circle::{CanonicCoset, CirclePoly, PolyOps, SecureEvaluation};
use crate::core::poly::line::{LineDomain, LineEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::qm31;
#[test]
fn test_fold_line() {
const LOG_SIZE: u32 = 7;
let mut rng = SmallRng::seed_from_u64(0);
let values = (0..1 << LOG_SIZE).map(|_| rng.gen()).collect_vec();
let alpha = qm31!(1, 3, 5, 7);
let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset());
let cpu_fold = CpuBackend::fold_line(
&LineEvaluation::new(domain, values.iter().copied().collect()),
alpha,
&CpuBackend::precompute_twiddles(domain.coset()),
);
let avx_fold = SimdBackend::fold_line(
&LineEvaluation::new(domain, values.iter().copied().collect()),
alpha,
&SimdBackend::precompute_twiddles(domain.coset()),
);
assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec());
}
#[test]
fn test_fold_circle_into_line() {
const LOG_SIZE: u32 = 7;
let values: Vec<SecureField> = (0..(1 << LOG_SIZE))
.map(|i| qm31!(4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3))
.collect();
let alpha = qm31!(1, 3, 5, 7);
let circle_domain = CanonicCoset::new(LOG_SIZE).circle_domain();
let line_domain = LineDomain::new(circle_domain.half_coset);
let mut cpu_fold = LineEvaluation::new(
line_domain,
SecureColumnByCoords::zeros(1 << (LOG_SIZE - 1)),
);
CpuBackend::fold_circle_into_line(
&mut cpu_fold,
&SecureEvaluation::new(circle_domain, values.iter().copied().collect()),
alpha,
&CpuBackend::precompute_twiddles(line_domain.coset()),
);
let mut simd_fold = LineEvaluation::new(
line_domain,
SecureColumnByCoords::zeros(1 << (LOG_SIZE - 1)),
);
SimdBackend::fold_circle_into_line(
&mut simd_fold,
&SecureEvaluation::new(circle_domain, values.iter().copied().collect()),
alpha,
&SimdBackend::precompute_twiddles(line_domain.coset()),
);
assert_eq!(cpu_fold.values.to_vec(), simd_fold.values.to_vec());
}
#[test]
fn decomposition_test() {
const DOMAIN_LOG_SIZE: u32 = 5;
const DOMAIN_LOG_HALF_SIZE: u32 = DOMAIN_LOG_SIZE - 1;
let s = CanonicCoset::new(DOMAIN_LOG_SIZE);
let domain = s.circle_domain();
let mut coeffs = BaseColumn::zeros(1 << DOMAIN_LOG_SIZE);
// Polynomial is out of FFT space.
coeffs.as_mut_slice()[1 << DOMAIN_LOG_HALF_SIZE] = BaseField::one();
let poly = CirclePoly::<SimdBackend>::new(coeffs);
let values = poly.evaluate(domain);
let avx_column = SecureColumnByCoords::<SimdBackend> {
columns: [
values.values.clone(),
values.values.clone(),
values.values.clone(),
values.values.clone(),
],
};
let avx_eval = SecureEvaluation::new(domain, avx_column.clone());
let cpu_eval =
SecureEvaluation::<CpuBackend, BitReversedOrder>::new(domain, avx_eval.to_cpu());
let (cpu_g, cpu_lambda) = CpuBackend::decompose(&cpu_eval);
let (avx_g, avx_lambda) = SimdBackend::decompose(&avx_eval);
assert_eq!(avx_lambda, cpu_lambda);
for i in 0..1 << DOMAIN_LOG_SIZE {
assert_eq!(avx_g.values.at(i), cpu_g.values.at(i));
}
}
}

View File

@ -0,0 +1,95 @@
use std::simd::cmp::SimdPartialOrd;
use std::simd::num::SimdUint;
use std::simd::u32x16;
use bytemuck::cast_slice;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use super::blake2s::compress16;
use super::SimdBackend;
use crate::core::backend::simd::m31::N_LANES;
use crate::core::channel::Blake2sChannel;
#[cfg(not(target_arch = "wasm32"))]
use crate::core::channel::{Channel, Poseidon252Channel, PoseidonBLSChannel};
use crate::core::proof_of_work::GrindOps;
// Note: GRIND_LOW_BITS is a cap on how much extra time we need to wait for all threads to finish.
const GRIND_LOW_BITS: u32 = 20;
const GRIND_HI_BITS: u32 = 64 - GRIND_LOW_BITS;
impl GrindOps<Blake2sChannel> for SimdBackend {
fn grind(channel: &Blake2sChannel, pow_bits: u32) -> u64 {
// TODO(spapini): support more than 32 bits.
assert!(pow_bits <= 32, "pow_bits > 32 is not supported");
let digest = channel.digest();
let digest: &[u32] = cast_slice(&digest.0[..]);
#[cfg(not(feature = "parallel"))]
let res = (0..=(1 << GRIND_HI_BITS))
.find_map(|hi| grind_blake(digest, hi, pow_bits))
.expect("Grind failed to find a solution.");
#[cfg(feature = "parallel")]
let res = (0..=(1 << GRIND_HI_BITS))
.into_par_iter()
.find_map_any(|hi| grind_blake(digest, hi, pow_bits))
.expect("Grind failed to find a solution.");
res
}
}
fn grind_blake(digest: &[u32], hi: u64, pow_bits: u32) -> Option<u64> {
let zero: u32x16 = u32x16::default();
let pow_bits = u32x16::splat(pow_bits);
let state: [u32x16; 8] = std::array::from_fn(|i| u32x16::splat(digest[i]));
let mut attempt = [zero; 16];
attempt[0] = u32x16::splat((hi << GRIND_LOW_BITS) as u32);
attempt[0] += u32x16::from(std::array::from_fn(|i| i as u32));
attempt[1] = u32x16::splat((hi >> (32 - GRIND_LOW_BITS)) as u32);
for low in (0..(1 << GRIND_LOW_BITS)).step_by(N_LANES) {
let res = compress16(state, attempt, zero, zero, zero, zero);
let success_mask = res[0].trailing_zeros().simd_ge(pow_bits);
if success_mask.any() {
let i = success_mask.to_array().iter().position(|&x| x).unwrap();
return Some((hi << GRIND_LOW_BITS) + low as u64 + i as u64);
}
attempt[0] += u32x16::splat(N_LANES as u32);
}
None
}
// TODO(spapini): This is a naive implementation. Optimize it.
#[cfg(not(target_arch = "wasm32"))]
impl GrindOps<Poseidon252Channel> for SimdBackend {
fn grind(channel: &Poseidon252Channel, pow_bits: u32) -> u64 {
let mut nonce = 0;
loop {
let mut channel = channel.clone();
channel.mix_nonce(nonce);
if channel.trailing_zeros() >= pow_bits {
return nonce;
}
nonce += 1;
}
}
}
// TODO(spapini): This is a naive implementation. Optimize it.
#[cfg(not(target_arch = "wasm32"))]
impl GrindOps<PoseidonBLSChannel> for SimdBackend {
fn grind(channel: &PoseidonBLSChannel, pow_bits: u32) -> u64 {
let mut nonce = 0;
loop {
let mut channel = channel.clone();
channel.mix_nonce(nonce);
if channel.trailing_zeros() >= pow_bits {
return nonce;
}
nonce += 1;
}
}
}

View File

@ -0,0 +1,684 @@
use std::iter::zip;
use num_traits::Zero;
use crate::core::backend::cpu::lookups::gkr::gen_eq_evals as cpu_gen_eq_evals;
use crate::core::backend::simd::column::SecureColumn;
use crate::core::backend::simd::m31::{LOG_N_LANES, N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, CpuBackend};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::{
correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer,
};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::{Fraction, Reciprocal, UnivariatePoly};
impl GkrOps for SimdBackend {
#[allow(clippy::uninit_vec)]
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField> {
if y.len() < LOG_N_LANES as usize {
return Mle::new(cpu_gen_eq_evals(y, v).into_iter().collect());
}
// Start DP with CPU backend to avoid dealing with instances smaller than a SIMD vector.
let (y_last_chunk, y_rem) = y.split_last_chunk::<{ LOG_N_LANES as usize }>().unwrap();
let initial = SecureColumn::from_iter(cpu_gen_eq_evals(y_last_chunk, v));
assert_eq!(initial.len(), N_LANES);
let packed_len = 1 << y_rem.len();
let mut data = initial.data;
data.reserve(packed_len - data.len());
unsafe { data.set_len(packed_len) };
for (i, &y_j) in y_rem.iter().rev().enumerate() {
let packed_y_j = PackedSecureField::broadcast(y_j);
let (lhs_evals, rhs_evals) = data.split_at_mut(1 << i);
for (lhs, rhs) in zip(lhs_evals, rhs_evals) {
// Equivalent to:
// `rhs = eq(1, y_j) * lhs`,
// `lhs = eq(0, y_j) * lhs`
*rhs = *lhs * packed_y_j;
*lhs -= *rhs;
}
}
let length = packed_len * N_LANES;
Mle::new(SecureColumn { data, length })
}
fn next_layer(layer: &Layer<Self>) -> Layer<Self> {
// Offload to CPU backend to avoid dealing with instances smaller than a SIMD vector.
if layer.n_variables() as u32 <= LOG_N_LANES {
return into_simd_layer(layer.to_cpu().next_layer().unwrap());
}
match layer {
Layer::GrandProduct(col) => next_grand_product_layer(col),
Layer::LogUpGeneric {
numerators,
denominators,
} => next_logup_generic_layer(numerators, denominators),
Layer::LogUpMultiplicities {
numerators,
denominators,
} => next_logup_multiplicities_layer(numerators, denominators),
Layer::LogUpSingles { denominators } => next_logup_singles_layer(denominators),
}
}
fn sum_as_poly_in_first_variable(
h: &GkrMultivariatePolyOracle<'_, Self>,
claim: SecureField,
) -> UnivariatePoly<SecureField> {
let n_variables = h.n_variables();
let n_terms = 1 << n_variables.saturating_sub(1);
let eq_evals = h.eq_evals.as_ref();
// Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube.
let y = eq_evals.y();
// Offload to CPU backend to avoid dealing with instances smaller than a SIMD vector.
if n_terms < N_LANES {
return h.to_cpu().sum_as_poly_in_first_variable(claim);
}
let n_packed_terms = n_terms / N_LANES;
let packed_lambda = PackedSecureField::broadcast(h.lambda);
let (mut eval_at_0, mut eval_at_2) = match &h.input_layer {
Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_packed_terms),
Layer::LogUpGeneric {
numerators,
denominators,
} => eval_logup_generic_sum(
eq_evals,
numerators,
denominators,
n_packed_terms,
packed_lambda,
),
Layer::LogUpMultiplicities {
numerators,
denominators,
} => eval_logup_multiplicities_sum(
eq_evals,
numerators,
denominators,
n_packed_terms,
packed_lambda,
),
Layer::LogUpSingles { denominators } => {
eval_logup_singles_sum(eq_evals, denominators, n_packed_terms, packed_lambda)
}
};
eval_at_0 *= h.eq_fixed_var_correction;
eval_at_2 *= h.eq_fixed_var_correction;
correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables)
}
}
/// Generates the next GKR layer for Grand Product.
///
/// Assumption: `len(layer) > N_LANES`.
fn next_grand_product_layer(layer: &Mle<SimdBackend, SecureField>) -> Layer<SimdBackend> {
assert!(layer.len() > N_LANES);
let next_layer_len = layer.len() / 2;
let data = layer
.data
.array_chunks()
.map(|&[a, b]| {
let (evens, odds) = a.deinterleave(b);
evens * odds
})
.collect();
Layer::GrandProduct(Mle::new(SecureColumn {
data,
length: next_layer_len,
}))
}
/// Generates the next GKR layer for LogUp.
///
/// Assumption: `len(denominators) > N_LANES`.
fn next_logup_generic_layer(
numerators: &Mle<SimdBackend, SecureField>,
denominators: &Mle<SimdBackend, SecureField>,
) -> Layer<SimdBackend> {
assert!(denominators.len() > N_LANES);
assert_eq!(numerators.len(), denominators.len());
let next_layer_len = denominators.len() / 2;
let next_layer_packed_len = next_layer_len / N_LANES;
let mut next_numerators = Vec::with_capacity(next_layer_packed_len);
let mut next_denominators = Vec::with_capacity(next_layer_packed_len);
for i in 0..next_layer_packed_len {
let (n_even, n_odd) = numerators.data[i * 2].deinterleave(numerators.data[i * 2 + 1]);
let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]);
let Fraction {
numerator,
denominator,
} = Fraction::new(n_even, d_even) + Fraction::new(n_odd, d_odd);
next_numerators.push(numerator);
next_denominators.push(denominator);
}
let next_numerators = SecureColumn {
data: next_numerators,
length: next_layer_len,
};
let next_denominators = SecureColumn {
data: next_denominators,
length: next_layer_len,
};
Layer::LogUpGeneric {
numerators: Mle::new(next_numerators),
denominators: Mle::new(next_denominators),
}
}
/// Generates the next GKR layer for LogUp.
///
/// Assumption: `len(denominators) > N_LANES`.
// TODO(andrew): Code duplication of `next_logup_generic_layer`. Consider unifying these.
fn next_logup_multiplicities_layer(
numerators: &Mle<SimdBackend, BaseField>,
denominators: &Mle<SimdBackend, SecureField>,
) -> Layer<SimdBackend> {
assert!(denominators.len() > N_LANES);
assert_eq!(numerators.len(), denominators.len());
let next_layer_len = denominators.len() / 2;
let next_layer_packed_len = next_layer_len / N_LANES;
let mut next_numerators = Vec::with_capacity(next_layer_packed_len);
let mut next_denominators = Vec::with_capacity(next_layer_packed_len);
for i in 0..next_layer_packed_len {
let (n_even, n_odd) = numerators.data[i * 2].deinterleave(numerators.data[i * 2 + 1]);
let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]);
let Fraction {
numerator,
denominator,
} = Fraction::new(n_even, d_even) + Fraction::new(n_odd, d_odd);
next_numerators.push(numerator);
next_denominators.push(denominator);
}
let next_numerators = SecureColumn {
data: next_numerators,
length: next_layer_len,
};
let next_denominators = SecureColumn {
data: next_denominators,
length: next_layer_len,
};
Layer::LogUpGeneric {
numerators: Mle::new(next_numerators),
denominators: Mle::new(next_denominators),
}
}
/// Generates the next GKR layer for LogUp.
///
/// Assumption: `len(denominators) > N_LANES`.
fn next_logup_singles_layer(denominators: &Mle<SimdBackend, SecureField>) -> Layer<SimdBackend> {
assert!(denominators.len() > N_LANES);
let next_layer_len = denominators.len() / 2;
let next_layer_packed_len = next_layer_len / N_LANES;
let mut next_numerators = Vec::with_capacity(next_layer_packed_len);
let mut next_denominators = Vec::with_capacity(next_layer_packed_len);
for i in 0..next_layer_packed_len {
let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]);
let Fraction {
numerator,
denominator,
} = Reciprocal::new(d_even) + Reciprocal::new(d_odd);
next_numerators.push(numerator);
next_denominators.push(denominator);
}
let next_numerators = SecureColumn {
data: next_numerators,
length: next_layer_len,
};
let next_denominators = SecureColumn {
data: next_denominators,
length: next_layer_len,
};
Layer::LogUpGeneric {
numerators: Mle::new(next_numerators),
denominators: Mle::new(next_denominators),
}
}
/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`.
///
/// Output of the form: `(eval_at_0, eval_at_2)`.
fn eval_grand_product_sum(
eq_evals: &EqEvals<SimdBackend>,
col: &Mle<SimdBackend, SecureField>,
n_packed_terms: usize,
) -> (SecureField, SecureField) {
let mut packed_eval_at_0 = PackedSecureField::zero();
let mut packed_eval_at_2 = PackedSecureField::zero();
for i in 0..n_packed_terms {
// Input polynomial at points `(r, {0, 1, 2}, bits(i), v, {0, 1})`
// for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
let (inp_at_r0iv0, inp_at_r0iv1) = col.data[i * 2].deinterleave(col.data[i * 2 + 1]);
let (inp_at_r1iv0, inp_at_r1iv1) =
col.data[(n_packed_terms + i) * 2].deinterleave(col.data[(n_packed_terms + i) * 2 + 1]);
// Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)`
// => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)`
let inp_at_r2iv0 = inp_at_r1iv0.double() - inp_at_r0iv0;
let inp_at_r2iv1 = inp_at_r1iv1.double() - inp_at_r0iv1;
// Product polynomial `prod(x) = inp(x, 0) * inp(x, 1)` at points `(r, {0, 2}, bits(i), v)`.
// for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
let prod_at_r2iv = inp_at_r2iv0 * inp_at_r2iv1;
let prod_at_r0iv = inp_at_r0iv0 * inp_at_r0iv1;
let eq_eval_at_0iv = eq_evals.data[i];
packed_eval_at_0 += eq_eval_at_0iv * prod_at_r0iv;
packed_eval_at_2 += eq_eval_at_0iv * prod_at_r2iv;
}
(
packed_eval_at_0.pointwise_sum(),
packed_eval_at_2.pointwise_sum(),
)
}
fn eval_logup_generic_sum(
eq_evals: &EqEvals<SimdBackend>,
numerators: &Mle<SimdBackend, SecureField>,
denominators: &Mle<SimdBackend, SecureField>,
n_packed_terms: usize,
packed_lambda: PackedSecureField,
) -> (SecureField, SecureField) {
let mut packed_eval_at_0 = PackedSecureField::zero();
let mut packed_eval_at_2 = PackedSecureField::zero();
let inp_numer = &numerators.data;
let inp_denom = &denominators.data;
for i in 0..n_packed_terms {
// Input polynomials at points `(r, {0, 1, 2}, bits(i), v, {0, 1})`
// for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
let (inp_numer_at_r0iv0, inp_numer_at_r0iv1) =
inp_numer[i * 2].deinterleave(inp_numer[i * 2 + 1]);
let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) =
inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]);
let (inp_numer_at_r1iv0, inp_numer_at_r1iv1) = inp_numer[(n_packed_terms + i) * 2]
.deinterleave(inp_numer[(n_packed_terms + i) * 2 + 1]);
let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2]
.deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]);
// Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)`
// => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)`
let inp_numer_at_r2iv0 = inp_numer_at_r1iv0.double() - inp_numer_at_r0iv0;
let inp_numer_at_r2iv1 = inp_numer_at_r1iv1.double() - inp_numer_at_r0iv1;
let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0;
let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1;
// Fraction addition polynomials:
// - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)`
// - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`.
// at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
let Fraction {
numerator: numer_at_r0iv,
denominator: denom_at_r0iv,
} = Fraction::new(inp_numer_at_r0iv0, inp_denom_at_r0iv0)
+ Fraction::new(inp_numer_at_r0iv1, inp_denom_at_r0iv1);
let Fraction {
numerator: numer_at_r2iv,
denominator: denom_at_r2iv,
} = Fraction::new(inp_numer_at_r2iv0, inp_denom_at_r2iv0)
+ Fraction::new(inp_numer_at_r2iv1, inp_denom_at_r2iv1);
let eq_eval_at_0iv = eq_evals.data[i];
packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv);
packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv);
}
(
packed_eval_at_0.pointwise_sum(),
packed_eval_at_2.pointwise_sum(),
)
}
// TODO(andrew): Code duplication of `eval_logup_generic_sum`. Consider unifying these.
fn eval_logup_multiplicities_sum(
eq_evals: &EqEvals<SimdBackend>,
numerators: &Mle<SimdBackend, BaseField>,
denominators: &Mle<SimdBackend, SecureField>,
n_packed_terms: usize,
packed_lambda: PackedSecureField,
) -> (SecureField, SecureField) {
let mut packed_eval_at_0 = PackedSecureField::zero();
let mut packed_eval_at_2 = PackedSecureField::zero();
let inp_numer = &numerators.data;
let inp_denom = &denominators.data;
for i in 0..n_packed_terms {
// Input polynomials at points `(r, {0, 1, 2}, bits(i), v, {0, 1})`
// for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
let (inp_numer_at_r0iv0, inp_numer_at_r0iv1) =
inp_numer[i * 2].deinterleave(inp_numer[i * 2 + 1]);
let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) =
inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]);
let (inp_numer_at_r1iv0, inp_numer_at_r1iv1) = inp_numer[(n_packed_terms + i) * 2]
.deinterleave(inp_numer[(n_packed_terms + i) * 2 + 1]);
let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2]
.deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]);
// Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)`
// => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)`
let inp_numer_at_r2iv0 = inp_numer_at_r1iv0.double() - inp_numer_at_r0iv0;
let inp_numer_at_r2iv1 = inp_numer_at_r1iv1.double() - inp_numer_at_r0iv1;
let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0;
let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1;
// Fraction addition polynomials:
// - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)`
// - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`.
// at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
let Fraction {
numerator: numer_at_r0iv,
denominator: denom_at_r0iv,
} = Fraction::new(inp_numer_at_r0iv0, inp_denom_at_r0iv0)
+ Fraction::new(inp_numer_at_r0iv1, inp_denom_at_r0iv1);
let Fraction {
numerator: numer_at_r2iv,
denominator: denom_at_r2iv,
} = Fraction::new(inp_numer_at_r2iv0, inp_denom_at_r2iv0)
+ Fraction::new(inp_numer_at_r2iv1, inp_denom_at_r2iv1);
let eq_eval_at_0iv = eq_evals.data[i];
packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv);
packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv);
}
(
packed_eval_at_0.pointwise_sum(),
packed_eval_at_2.pointwise_sum(),
)
}
/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_denom(r, t, x, 1) + inp_denom(r, t, x, 0) +
/// lambda * inp_denom(r, t, x, 0) * inp_denom(r, t, x, 1))` at `t=0` and `t=2`.
///
/// Output of the form: `(eval_at_0, eval_at_2)`.
fn eval_logup_singles_sum(
eq_evals: &EqEvals<SimdBackend>,
denominators: &Mle<SimdBackend, SecureField>,
n_packed_terms: usize,
packed_lambda: PackedSecureField,
) -> (SecureField, SecureField) {
let mut packed_eval_at_0 = PackedSecureField::zero();
let mut packed_eval_at_2 = PackedSecureField::zero();
let inp_denom = &denominators.data;
for i in 0..n_packed_terms {
// Input polynomial at points `(r, {0, 1, 2}, bits(i), v, {0, 1})`
// for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) =
inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]);
let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2]
.deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]);
// Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)`
// => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)`
let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0;
let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1;
// Fraction addition polynomials:
// - `numer(x) = inp_denom(x, 1) + inp_denom(x, 0)`
// - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`.
// at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
let Fraction {
numerator: numer_at_r0iv,
denominator: denom_at_r0iv,
} = Reciprocal::new(inp_denom_at_r0iv0) + Reciprocal::new(inp_denom_at_r0iv1);
let Fraction {
numerator: numer_at_r2iv,
denominator: denom_at_r2iv,
} = Reciprocal::new(inp_denom_at_r2iv0) + Reciprocal::new(inp_denom_at_r2iv1);
let eq_eval_at_0iv = eq_evals.data[i];
packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv);
packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv);
}
(
packed_eval_at_0.pointwise_sum(),
packed_eval_at_2.pointwise_sum(),
)
}
fn into_simd_layer(cpu_layer: Layer<CpuBackend>) -> Layer<SimdBackend> {
match cpu_layer {
Layer::GrandProduct(mle) => {
Layer::GrandProduct(Mle::new(mle.into_evals().into_iter().collect()))
}
Layer::LogUpGeneric {
numerators,
denominators,
} => Layer::LogUpGeneric {
numerators: Mle::new(numerators.into_evals().into_iter().collect()),
denominators: Mle::new(denominators.into_evals().into_iter().collect()),
},
Layer::LogUpMultiplicities {
numerators,
denominators,
} => Layer::LogUpMultiplicities {
numerators: Mle::new(numerators.into_evals().into_iter().collect()),
denominators: Mle::new(denominators.into_evals().into_iter().collect()),
},
Layer::LogUpSingles { denominators } => Layer::LogUpSingles {
denominators: Mle::new(denominators.into_evals().into_iter().collect()),
},
}
}
#[cfg(test)]
mod tests {
use std::iter::zip;
use num_traits::One;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, CpuBackend};
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer};
use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::utils::Fraction;
use crate::core::test_utils::test_channel;
#[test]
fn gen_eq_evals_matches_cpu() {
let two = BaseField::from(2).into();
let y = [7, 3, 5, 6, 1, 1, 9].map(|v| BaseField::from(v).into());
let eq_evals_cpu = CpuBackend::gen_eq_evals(&y, two);
let eq_evals_simd = SimdBackend::gen_eq_evals(&y, two);
assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu);
}
#[test]
fn gen_eq_evals_with_small_assignment_matches_cpu() {
let two = BaseField::from(2).into();
let y = [7, 3, 5].map(|v| BaseField::from(v).into());
let eq_evals_cpu = CpuBackend::gen_eq_evals(&y, two);
let eq_evals_simd = SimdBackend::gen_eq_evals(&y, two);
assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu);
}
#[test]
fn grand_product_works() -> Result<(), GkrError> {
const N: usize = 1 << 8;
let values = test_channel().draw_felts(N);
let product = values.iter().product();
let col = Mle::<SimdBackend, SecureField>::new(values.into_iter().collect());
let input_layer = Layer::GrandProduct(col.clone());
let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;
assert_eq!(proof.output_claims_by_instance, [vec![product]]);
assert_eq!(
claims_to_verify_by_instance,
[vec![col.eval_at_point(&ood_point)]]
);
Ok(())
}
#[test]
fn logup_with_generic_trace_works() -> Result<(), GkrError> {
const N: usize = 1 << 8;
let mut rng = SmallRng::seed_from_u64(0);
let numerators = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let denominators = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let sum = zip(&numerators, &denominators)
.map(|(&n, &d)| Fraction::new(n, d))
.sum::<Fraction<SecureField, SecureField>>();
let numerators = Mle::<SimdBackend, SecureField>::new(numerators.into_iter().collect());
let denominators = Mle::<SimdBackend, SecureField>::new(denominators.into_iter().collect());
let input_layer = Layer::LogUpGeneric {
numerators: numerators.clone(),
denominators: denominators.clone(),
};
let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;
assert_eq!(claims_to_verify_by_instance.len(), 1);
assert_eq!(proof.output_claims_by_instance.len(), 1);
assert_eq!(
claims_to_verify_by_instance[0],
[
numerators.eval_at_point(&ood_point),
denominators.eval_at_point(&ood_point)
]
);
assert_eq!(
proof.output_claims_by_instance[0],
[sum.numerator, sum.denominator]
);
Ok(())
}
#[test]
fn logup_with_multiplicities_trace_works() -> Result<(), GkrError> {
const N: usize = 1 << 8;
let mut rng = SmallRng::seed_from_u64(0);
let numerators = (0..N).map(|_| rng.gen()).collect::<Vec<BaseField>>();
let denominators = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let sum = zip(&numerators, &denominators)
.map(|(&n, &d)| Fraction::new(n.into(), d))
.sum::<Fraction<SecureField, SecureField>>();
let numerators = Mle::<SimdBackend, BaseField>::new(numerators.into_iter().collect());
let denominators = Mle::<SimdBackend, SecureField>::new(denominators.into_iter().collect());
let input_layer = Layer::LogUpMultiplicities {
numerators: numerators.clone(),
denominators: denominators.clone(),
};
let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;
assert_eq!(claims_to_verify_by_instance.len(), 1);
assert_eq!(proof.output_claims_by_instance.len(), 1);
assert_eq!(
claims_to_verify_by_instance[0],
[
numerators.eval_at_point(&ood_point),
denominators.eval_at_point(&ood_point)
]
);
assert_eq!(
proof.output_claims_by_instance[0],
[sum.numerator, sum.denominator]
);
Ok(())
}
#[test]
fn logup_with_singles_trace_works() -> Result<(), GkrError> {
const N: usize = 1 << 8;
let mut rng = SmallRng::seed_from_u64(0);
let denominators = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let sum = denominators
.iter()
.map(|&d| Fraction::new(SecureField::one(), d))
.sum::<Fraction<SecureField, SecureField>>();
let denominators = Mle::<SimdBackend, SecureField>::new(denominators.into_iter().collect());
let input_layer = Layer::LogUpSingles {
denominators: denominators.clone(),
};
let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
n_variables_by_instance: _,
} = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;
assert_eq!(claims_to_verify_by_instance.len(), 1);
assert_eq!(proof.output_claims_by_instance.len(), 1);
assert_eq!(
claims_to_verify_by_instance[0],
[SecureField::one(), denominators.eval_at_point(&ood_point)]
);
assert_eq!(
proof.output_claims_by_instance[0],
[sum.numerator, sum.denominator]
);
Ok(())
}
}

View File

@ -0,0 +1,132 @@
use core::ops::Sub;
use std::iter::zip;
use std::ops::{Add, Mul};
use crate::core::backend::simd::column::SecureColumn;
use crate::core::backend::simd::m31::N_LANES;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, CpuBackend};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::mle::{Mle, MleOps};
impl MleOps<BaseField> for SimdBackend {
fn fix_first_variable(
mle: Mle<Self, BaseField>,
assignment: SecureField,
) -> Mle<Self, SecureField> {
let midpoint = mle.len() / 2;
// Use CPU backend to avoid dealing with instances smaller than `PackedSecureField`.
if midpoint < N_LANES {
let cpu_mle = Mle::<CpuBackend, BaseField>::new(mle.to_cpu());
let cpu_res = cpu_mle.fix_first_variable(assignment);
return Mle::new(cpu_res.into_evals().into_iter().collect());
}
let packed_assignment = PackedSecureField::broadcast(assignment);
let packed_midpoint = midpoint / N_LANES;
let (evals_at_0x, evals_at_1x) = mle.data.split_at(packed_midpoint);
let res = zip(evals_at_0x, evals_at_1x)
.enumerate()
// MLE at points `({0, 1}, rev(bits(i)), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
.map(|(_i, (&packed_eval_at_0iv, &packed_eval_at_1iv))| {
fold_packed_mle_evals(packed_assignment, packed_eval_at_0iv, packed_eval_at_1iv)
})
.collect();
Mle::new(res)
}
}
impl MleOps<SecureField> for SimdBackend {
fn fix_first_variable(
mle: Mle<Self, SecureField>,
assignment: SecureField,
) -> Mle<Self, SecureField> {
let midpoint = mle.len() / 2;
// Use CPU backend to avoid dealing with instances smaller than `PackedSecureField`.
if midpoint < N_LANES {
let cpu_mle = Mle::<CpuBackend, SecureField>::new(mle.to_cpu());
let cpu_res = cpu_mle.fix_first_variable(assignment);
return Mle::new(cpu_res.into_evals().into_iter().collect());
}
let packed_midpoint = midpoint / N_LANES;
let packed_assignment = PackedSecureField::broadcast(assignment);
let mut packed_evals = mle.into_evals().data;
for i in 0..packed_midpoint {
// MLE at points `({0, 1}, rev(bits(i)), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`.
let packed_eval_at_0iv = packed_evals[i];
let packed_eval_at_1iv = packed_evals[i + packed_midpoint];
packed_evals[i] =
fold_packed_mle_evals(packed_assignment, packed_eval_at_0iv, packed_eval_at_1iv);
}
packed_evals.truncate(packed_midpoint);
let length = packed_evals.len() * N_LANES;
let data = packed_evals;
Mle::new(SecureColumn { data, length })
}
}
/// Computes all `eq(0, assignment_i) * eval0_i + eq(1, assignment_i) * eval1_i`.
// TODO(andrew): Remove complex trait bounds once we have something like
// AbstractField/AbstractExtensionField traits.
fn fold_packed_mle_evals<
PackedF: Sub<Output = PackedF> + Copy,
PackedEF: Mul<PackedF, Output = PackedEF> + Add<PackedF, Output = PackedEF>,
>(
assignment: PackedEF,
eval0: PackedF,
eval1: PackedF,
) -> PackedEF {
assignment * (eval1 - eval0) + eval0
}
#[cfg(test)]
mod tests {
use itertools::Itertools;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, CpuBackend};
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::mle::Mle;
use crate::core::test_utils::test_channel;
#[test]
fn fix_first_variable_with_secure_field_mle_matches_cpu() {
const N_VARIABLES: u32 = 8;
let values = test_channel().draw_felts(1 << N_VARIABLES);
let mle_simd = Mle::<SimdBackend, SecureField>::new(values.iter().copied().collect());
let mle_cpu = Mle::<CpuBackend, SecureField>::new(values);
let random_assignment = SecureField::from_u32_unchecked(7, 12, 3, 2);
let mle_fixed_cpu = mle_cpu.fix_first_variable(random_assignment);
let mle_fixed_simd = mle_simd.fix_first_variable(random_assignment);
assert_eq!(mle_fixed_simd.into_evals().to_cpu(), *mle_fixed_cpu)
}
#[test]
fn fix_first_variable_with_base_field_mle_matches_cpu() {
const N_VARIABLES: u32 = 8;
let values = (0..1 << N_VARIABLES).map(BaseField::from).collect_vec();
let mle_simd = Mle::<SimdBackend, BaseField>::new(values.iter().copied().collect());
let mle_cpu = Mle::<CpuBackend, BaseField>::new(values);
let random_assignment = SecureField::from_u32_unchecked(7, 12, 3, 2);
let mle_fixed_cpu = mle_cpu.fix_first_variable(random_assignment);
let mle_fixed_simd = mle_simd.fix_first_variable(random_assignment);
assert_eq!(mle_fixed_simd.into_evals().to_cpu(), *mle_fixed_cpu)
}
}

View File

@ -0,0 +1,2 @@
mod gkr;
mod mle;

View File

@ -0,0 +1,666 @@
use std::iter::Sum;
use std::mem::transmute;
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use std::ptr;
use std::simd::cmp::SimdOrd;
use std::simd::{u32x16, Simd, Swizzle};
use bytemuck::{Pod, Zeroable};
use num_traits::{One, Zero};
use rand::distributions::{Distribution, Standard};
use super::qm31::PackedQM31;
use crate::core::backend::simd::utils::{InterleaveEvens, InterleaveOdds};
use crate::core::fields::m31::{pow2147483645, BaseField, M31, P};
use crate::core::fields::qm31::QM31;
use crate::core::fields::FieldExpOps;
pub const LOG_N_LANES: u32 = 4;
pub const N_LANES: usize = 1 << LOG_N_LANES;
pub const MODULUS: Simd<u32, N_LANES> = Simd::from_array([P; N_LANES]);
pub type PackedBaseField = PackedM31;
/// Holds a vector of unreduced [`M31`] elements in the range `[0, P]`.
///
/// Implemented with [`std::simd`] to support multiple targets (avx512, neon, wasm etc.).
// TODO: Remove `pub` visibility
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct PackedM31(Simd<u32, N_LANES>);
impl PackedM31 {
/// Constructs a new instance with all vector elements set to `value`.
pub fn broadcast(M31(value): M31) -> Self {
Self(Simd::splat(value))
}
pub fn from_array(values: [M31; N_LANES]) -> PackedM31 {
Self(Simd::from_array(values.map(|M31(v)| v)))
}
pub fn to_array(self) -> [M31; N_LANES] {
self.reduce().0.to_array().map(M31)
}
/// Reduces each element of the vector to the range `[0, P)`.
fn reduce(self) -> PackedM31 {
Self(Simd::simd_min(self.0, self.0 - MODULUS))
}
/// Interleaves two vectors.
pub fn interleave(self, other: Self) -> (Self, Self) {
let (a, b) = self.0.interleave(other.0);
(Self(a), Self(b))
}
/// Deinterleaves two vectors.
pub fn deinterleave(self, other: Self) -> (Self, Self) {
let (a, b) = self.0.deinterleave(other.0);
(Self(a), Self(b))
}
/// Reverses the order of the elements in the vector.
pub fn reverse(self) -> Self {
Self(self.0.reverse())
}
/// Sums all the elements in the vector.
pub fn pointwise_sum(self) -> M31 {
self.to_array().into_iter().sum()
}
/// Doubles each element in the vector.
pub fn double(self) -> Self {
// TODO: Make more optimal.
self + self
}
pub fn into_simd(self) -> Simd<u32, N_LANES> {
self.0
}
/// # Safety
///
/// Vector elements must be in the range `[0, P]`.
pub unsafe fn from_simd_unchecked(v: Simd<u32, N_LANES>) -> Self {
Self(v)
}
/// # Safety
///
/// Behavior is undefined if the pointer does not have the same alignment as
/// [`PackedM31`]. The loaded `u32` values must be in the range `[0, P]`.
pub unsafe fn load(mem_addr: *const u32) -> Self {
Self(ptr::read(mem_addr as *const u32x16))
}
/// # Safety
///
/// Behavior is undefined if the pointer does not have the same alignment as
/// [`PackedM31`].
pub unsafe fn store(self, dst: *mut u32) {
ptr::write(dst as *mut u32x16, self.0)
}
}
impl Add for PackedM31 {
type Output = Self;
#[inline(always)]
fn add(self, rhs: Self) -> Self::Output {
// Add word by word. Each word is in the range [0, 2P].
let c = self.0 + rhs.0;
// Apply min(c, c-P) to each word.
// When c in [P,2P], then c-P in [0,P] which is always less than [P,2P].
// When c in [0,P-1], then c-P in [2^32-P,2^32-1] which is always greater than [0,P-1].
Self(Simd::simd_min(c, c - MODULUS))
}
}
impl AddAssign for PackedM31 {
#[inline(always)]
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
impl AddAssign<M31> for PackedM31 {
#[inline(always)]
fn add_assign(&mut self, rhs: M31) {
*self = *self + PackedM31::broadcast(rhs);
}
}
impl Mul for PackedM31 {
type Output = Self;
#[inline(always)]
fn mul(self, rhs: Self) -> Self {
// TODO: Come up with a better approach than `cfg`ing on target_feature.
// TODO: Ensure all these branches get tested in the CI.
cfg_if::cfg_if! {
if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] {
_mul_neon(self, rhs)
} else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] {
_mul_wasm(self, rhs)
} else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] {
_mul_avx512(self, rhs)
} else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] {
_mul_avx2(self, rhs)
} else {
_mul_simd(self, rhs)
}
}
}
}
impl Mul<M31> for PackedM31 {
type Output = Self;
#[inline(always)]
fn mul(self, rhs: M31) -> Self::Output {
self * PackedM31::broadcast(rhs)
}
}
impl Add<M31> for PackedM31 {
type Output = PackedM31;
#[inline(always)]
fn add(self, rhs: M31) -> Self::Output {
PackedM31::broadcast(rhs) + self
}
}
impl Add<QM31> for PackedM31 {
type Output = PackedQM31;
#[inline(always)]
fn add(self, rhs: QM31) -> Self::Output {
PackedQM31::broadcast(rhs) + self
}
}
impl Mul<QM31> for PackedM31 {
type Output = PackedQM31;
#[inline(always)]
fn mul(self, rhs: QM31) -> Self::Output {
PackedQM31::broadcast(rhs) * self
}
}
impl MulAssign for PackedM31 {
#[inline(always)]
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
impl Neg for PackedM31 {
type Output = Self;
#[inline(always)]
fn neg(self) -> Self::Output {
Self(MODULUS - self.0)
}
}
impl Sub for PackedM31 {
type Output = Self;
#[inline(always)]
fn sub(self, rhs: Self) -> Self::Output {
// Subtract word by word. Each word is in the range [-P, P].
let c = self.0 - rhs.0;
// Apply min(c, c+P) to each word.
// When c in [0,P], then c+P in [P,2P] which is always greater than [0,P].
// When c in [2^32-P,2^32-1], then c+P in [0,P-1] which is always less than
// [2^32-P,2^32-1].
Self(Simd::simd_min(c + MODULUS, c))
}
}
impl SubAssign for PackedM31 {
#[inline(always)]
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
}
}
impl Zero for PackedM31 {
fn zero() -> Self {
Self(Simd::from_array([0; N_LANES]))
}
fn is_zero(&self) -> bool {
self.to_array().iter().all(M31::is_zero)
}
}
impl One for PackedM31 {
fn one() -> Self {
Self(Simd::<u32, N_LANES>::from_array([1; N_LANES]))
}
}
impl FieldExpOps for PackedM31 {
fn inverse(&self) -> Self {
assert!(!self.is_zero(), "0 has no inverse");
pow2147483645(*self)
}
}
unsafe impl Pod for PackedM31 {}
unsafe impl Zeroable for PackedM31 {
fn zeroed() -> Self {
unsafe { core::mem::zeroed() }
}
}
impl From<[BaseField; N_LANES]> for PackedM31 {
fn from(v: [BaseField; N_LANES]) -> Self {
Self::from_array(v)
}
}
impl From<BaseField> for PackedM31 {
fn from(v: BaseField) -> Self {
Self::broadcast(v)
}
}
impl Distribution<PackedM31> for Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> PackedM31 {
PackedM31::from_array(rng.gen())
}
}
impl Sum for PackedM31 {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(Self::zero(), Add::add)
}
}
/// Returns `a * b`.
#[cfg(target_arch = "aarch64")]
pub(crate) fn _mul_neon(a: PackedM31, b: PackedM31) -> PackedM31 {
use core::arch::aarch64::{int32x2_t, vqdmull_s32};
use std::simd::u32x4;
let [a0, a1, a2, a3, a4, a5, a6, a7]: [int32x2_t; 8] = unsafe { transmute(a) };
let [b0, b1, b2, b3, b4, b5, b6, b7]: [int32x2_t; 8] = unsafe { transmute(b) };
// Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0|
let c0: u32x4 = unsafe { transmute(vqdmull_s32(a0, b0)) };
let c1: u32x4 = unsafe { transmute(vqdmull_s32(a1, b1)) };
let c2: u32x4 = unsafe { transmute(vqdmull_s32(a2, b2)) };
let c3: u32x4 = unsafe { transmute(vqdmull_s32(a3, b3)) };
let c4: u32x4 = unsafe { transmute(vqdmull_s32(a4, b4)) };
let c5: u32x4 = unsafe { transmute(vqdmull_s32(a5, b5)) };
let c6: u32x4 = unsafe { transmute(vqdmull_s32(a6, b6)) };
let c7: u32x4 = unsafe { transmute(vqdmull_s32(a7, b7)) };
// *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`.
// *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`.
let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1);
let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3);
let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5);
let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7);
// *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`.
c0_c1_lo >>= 1;
c2_c3_lo >>= 1;
c4_c5_lo >>= 1;
c6_c7_lo >>= 1;
let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) };
let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) };
lo + hi
}
/// Returns `a * b`.
///
/// `b_double` should be in the range `[0, 2P]`.
#[cfg(target_arch = "aarch64")]
pub(crate) fn _mul_doubled_neon(a: PackedM31, b_double: u32x16) -> PackedM31 {
use core::arch::aarch64::{uint32x2_t, vmull_u32};
use std::simd::u32x4;
let [a0, a1, a2, a3, a4, a5, a6, a7]: [uint32x2_t; 8] = unsafe { transmute(a) };
let [b0, b1, b2, b3, b4, b5, b6, b7]: [uint32x2_t; 8] = unsafe { transmute(b_double) };
// Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0|
let c0: u32x4 = unsafe { transmute(vmull_u32(a0, b0)) };
let c1: u32x4 = unsafe { transmute(vmull_u32(a1, b1)) };
let c2: u32x4 = unsafe { transmute(vmull_u32(a2, b2)) };
let c3: u32x4 = unsafe { transmute(vmull_u32(a3, b3)) };
let c4: u32x4 = unsafe { transmute(vmull_u32(a4, b4)) };
let c5: u32x4 = unsafe { transmute(vmull_u32(a5, b5)) };
let c6: u32x4 = unsafe { transmute(vmull_u32(a6, b6)) };
let c7: u32x4 = unsafe { transmute(vmull_u32(a7, b7)) };
// *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`.
// *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`.
let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1);
let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3);
let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5);
let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7);
// *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`.
c0_c1_lo >>= 1;
c2_c3_lo >>= 1;
c4_c5_lo >>= 1;
c6_c7_lo >>= 1;
let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) };
let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) };
lo + hi
}
/// Returns `a * b`.
#[cfg(target_arch = "wasm32")]
pub(crate) fn _mul_wasm(a: PackedM31, b: PackedM31) -> PackedM31 {
_mul_doubled_wasm(a, b.0 + b.0)
}
/// Returns `a * b`.
///
/// `b_double` should be in the range `[0, 2P]`.
#[cfg(target_arch = "wasm32")]
pub(crate) fn _mul_doubled_wasm(a: PackedM31, b_double: u32x16) -> PackedM31 {
use core::arch::wasm32::{i64x2_extmul_high_u32x4, i64x2_extmul_low_u32x4, v128};
use std::simd::u32x4;
let [a0, a1, a2, a3]: [v128; 4] = unsafe { transmute(a) };
let [b_double0, b_double1, b_double2, b_double3]: [v128; 4] = unsafe { transmute(b_double) };
let c0_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a0, b_double0)) };
let c0_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a0, b_double0)) };
let c1_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a1, b_double1)) };
let c1_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a1, b_double1)) };
let c2_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a2, b_double2)) };
let c2_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a2, b_double2)) };
let c3_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a3, b_double3)) };
let c3_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a3, b_double3)) };
let (mut c0_even, c0_odd) = c0_lo.deinterleave(c0_hi);
let (mut c1_even, c1_odd) = c1_lo.deinterleave(c1_hi);
let (mut c2_even, c2_odd) = c2_lo.deinterleave(c2_hi);
let (mut c3_even, c3_odd) = c3_lo.deinterleave(c3_hi);
c0_even >>= 1;
c1_even >>= 1;
c2_even >>= 1;
c3_even >>= 1;
let even: PackedM31 = unsafe { transmute([c0_even, c1_even, c2_even, c3_even]) };
let odd: PackedM31 = unsafe { transmute([c0_odd, c1_odd, c2_odd, c3_odd]) };
even + odd
}
/// Returns `a * b`.
#[cfg(target_arch = "x86_64")]
pub(crate) fn _mul_avx512(a: PackedM31, b: PackedM31) -> PackedM31 {
_mul_doubled_avx512(a, b.0 + b.0)
}
/// Returns `a * b`.
///
/// `b_double` should be in the range `[0, 2P]`.
#[cfg(target_arch = "x86_64")]
pub(crate) fn _mul_doubled_avx512(a: PackedM31, b_double: u32x16) -> PackedM31 {
use std::arch::x86_64::{__m512i, _mm512_mul_epu32, _mm512_srli_epi64};
let a: __m512i = unsafe { transmute(a) };
let b_double: __m512i = unsafe { transmute(b_double) };
// Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of
// the first operand.
let a_e = a;
// Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of
// the first operand.
let a_o = unsafe { _mm512_srli_epi64(a, 32) };
let b_dbl_e = b_double;
let b_dbl_o = unsafe { _mm512_srli_epi64(b_double, 32) };
// To compute prod = a * b start by multiplying a_e/odd by b_dbl_e/odd.
let prod_dbl_e: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_e, b_dbl_e)) };
let prod_dbl_o: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_o, b_dbl_o)) };
// The result of a multiplication holds a*b in as 64-bits.
// Each 64b-bit word looks like this:
// 1 31 31 1
// prod_dbl_e - |0|prod_e_h|prod_e_l|0|
// prod_dbl_o - |0|prod_o_h|prod_o_l|0|
// Interleave the even words of prod_dbl_e with the even words of prod_dbl_o:
let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o);
// prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0|
// Divide by 2:
prod_lo >>= 1;
// prod_lo - |0|prod_o_l|0|prod_e_l|
// Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o:
let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o);
// prod_hi - |0|prod_o_h|0|prod_e_h|
PackedM31(prod_lo) + PackedM31(prod_hi)
}
/// Returns `a * b`.
#[cfg(target_arch = "x86_64")]
pub(crate) fn _mul_avx2(a: PackedM31, b: PackedM31) -> PackedM31 {
_mul_doubled_avx2(a, b.0 + b.0)
}
/// Returns `a * b`.
///
/// `b_double` should be in the range `[0, 2P]`.
#[cfg(target_arch = "x86_64")]
pub(crate) fn _mul_doubled_avx2(a: PackedM31, b_double: u32x16) -> PackedM31 {
use std::arch::x86_64::{__m256i, _mm256_mul_epu32, _mm256_srli_epi64};
let [a0, a1]: [__m256i; 2] = unsafe { transmute(a) };
let [b0_dbl, b1_dbl]: [__m256i; 2] = unsafe { transmute(b_double) };
// Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of
// the first operand.
let a0_e = a0;
let a1_e = a1;
// Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of
// the first operand.
let a0_o = unsafe { _mm256_srli_epi64(a0, 32) };
let a1_o = unsafe { _mm256_srli_epi64(a1, 32) };
let b0_dbl_e = b0_dbl;
let b1_dbl_e = b1_dbl;
let b0_dbl_o = unsafe { _mm256_srli_epi64(b0_dbl, 32) };
let b1_dbl_o = unsafe { _mm256_srli_epi64(b1_dbl, 32) };
// To compute prod = a * b start by multiplying a0/1_e/odd by b0/1_e/odd.
let prod0_dbl_e = unsafe { _mm256_mul_epu32(a0_e, b0_dbl_e) };
let prod0_dbl_o = unsafe { _mm256_mul_epu32(a0_o, b0_dbl_o) };
let prod1_dbl_e = unsafe { _mm256_mul_epu32(a1_e, b1_dbl_e) };
let prod1_dbl_o = unsafe { _mm256_mul_epu32(a1_o, b1_dbl_o) };
let prod_dbl_e: u32x16 = unsafe { transmute([prod0_dbl_e, prod1_dbl_e]) };
let prod_dbl_o: u32x16 = unsafe { transmute([prod0_dbl_o, prod1_dbl_o]) };
// The result of a multiplication holds a*b in as 64-bits.
// Each 64b-bit word looks like this:
// 1 31 31 1
// prod_dbl_e - |0|prod_e_h|prod_e_l|0|
// prod_dbl_o - |0|prod_o_h|prod_o_l|0|
// Interleave the even words of prod_dbl_e with the even words of prod_dbl_o:
let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o);
// prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0|
// Divide by 2:
prod_lo >>= 1;
// prod_lo - |0|prod_o_l|0|prod_e_l|
// Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o:
let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o);
// prod_hi - |0|prod_o_h|0|prod_e_h|
PackedM31(prod_lo) + PackedM31(prod_hi)
}
/// Returns `a * b`.
///
/// Should only be used in the absence of a platform specific implementation.
pub(crate) fn _mul_simd(a: PackedM31, b: PackedM31) -> PackedM31 {
_mul_doubled_simd(a, b.0 + b.0)
}
/// Returns `a * b`.
///
/// Should only be used in the absence of a platform specific implementation.
///
/// `b_double` should be in the range `[0, 2P]`.
pub(crate) fn _mul_doubled_simd(a: PackedM31, b_double: u32x16) -> PackedM31 {
const MASK_EVENS: Simd<u64, { N_LANES / 2 }> = Simd::from_array([0xFFFFFFFF; { N_LANES / 2 }]);
// Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of
// the first operand.
let a_e = unsafe { transmute::<_, Simd<u64, { N_LANES / 2 }>>(a.0) & MASK_EVENS };
// Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of
// the first operand.
let a_o = unsafe { transmute::<_, Simd<u64, { N_LANES / 2 }>>(a) >> 32 };
let b_dbl_e = unsafe { transmute::<_, Simd<u64, { N_LANES / 2 }>>(b_double) & MASK_EVENS };
let b_dbl_o = unsafe { transmute::<_, Simd<u64, { N_LANES / 2 }>>(b_double) >> 32 };
// To compute prod = a * b start by multiplying
// a_e/o by b_dbl_e/o.
let prod_e_dbl = a_e * b_dbl_e;
let prod_o_dbl = a_o * b_dbl_o;
// The result of a multiplication holds a*b in as 64-bits.
// Each 64b-bit word looks like this:
// 1 31 31 1
// prod_e_dbl - |0|prod_e_h|prod_e_l|0|
// prod_o_dbl - |0|prod_o_h|prod_o_l|0|
// Interleave the even words of prod_e_dbl with the even words of prod_o_dbl:
// let prod_lows = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS,
// prod_o_dbl);
// prod_ls - |prod_o_l|0|prod_e_l|0|
let mut prod_lows = InterleaveEvens::concat_swizzle(
unsafe { transmute::<_, Simd<u32, N_LANES>>(prod_e_dbl) },
unsafe { transmute::<_, Simd<u32, N_LANES>>(prod_o_dbl) },
);
// Divide by 2:
prod_lows >>= 1;
// prod_ls - |0|prod_o_l|0|prod_e_l|
// Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl:
let prod_highs = InterleaveOdds::concat_swizzle(
unsafe { transmute::<_, Simd<u32, N_LANES>>(prod_e_dbl) },
unsafe { transmute::<_, Simd<u32, N_LANES>>(prod_o_dbl) },
);
// prod_hs - |0|prod_o_h|0|prod_e_h|
PackedM31(prod_lows) + PackedM31(prod_highs)
}
#[cfg(test)]
mod tests {
use std::array;
use aligned::{Aligned, A64};
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use super::PackedM31;
use crate::core::fields::m31::BaseField;
use crate::core::fields::FieldExpOps;
#[test]
fn addition_works() {
let mut rng = SmallRng::seed_from_u64(0);
let lhs = rng.gen();
let rhs = rng.gen();
let packed_lhs = PackedM31::from_array(lhs);
let packed_rhs = PackedM31::from_array(rhs);
let res = packed_lhs + packed_rhs;
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i]));
}
#[test]
fn subtraction_works() {
let mut rng = SmallRng::seed_from_u64(0);
let lhs = rng.gen();
let rhs = rng.gen();
let packed_lhs = PackedM31::from_array(lhs);
let packed_rhs = PackedM31::from_array(rhs);
let res = packed_lhs - packed_rhs;
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i]));
}
#[test]
fn multiplication_works() {
let mut rng = SmallRng::seed_from_u64(0);
let lhs = rng.gen();
let rhs = rng.gen();
let packed_lhs = PackedM31::from_array(lhs);
let packed_rhs = PackedM31::from_array(rhs);
let res = packed_lhs * packed_rhs;
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i]));
}
#[test]
fn negation_works() {
let mut rng = SmallRng::seed_from_u64(0);
let values = rng.gen();
let packed_values = PackedM31::from_array(values);
let res = -packed_values;
assert_eq!(res.to_array(), array::from_fn(|i| -values[i]));
}
#[test]
fn load_works() {
let v: Aligned<A64, [u32; 16]> = Aligned(array::from_fn(|i| i as u32));
let res = unsafe { PackedM31::load(v.as_ptr()) };
assert_eq!(res.to_array().map(|v| v.0), *v);
}
#[test]
fn store_works() {
let v = PackedM31::from_array(array::from_fn(BaseField::from));
let mut res: Aligned<A64, [u32; 16]> = Aligned([0; 16]);
unsafe { v.store(res.as_mut_ptr()) };
assert_eq!(*res, v.to_array().map(|v| v.0));
}
#[test]
fn inverse_works() {
let mut rng = SmallRng::seed_from_u64(0);
let values = rng.gen();
let packed_values = PackedM31::from_array(values);
let res = packed_values.inverse();
assert_eq!(res.to_array(), array::from_fn(|i| values[i].inverse()));
}
}

View File

@ -0,0 +1,41 @@
use serde::{Deserialize, Serialize};
use super::{Backend, BackendForChannel};
use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel;
#[cfg(not(target_arch = "wasm32"))]
use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel;
#[cfg(not(target_arch = "wasm32"))]
use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleChannel;
pub mod accumulation;
pub mod bit_reverse;
pub mod blake2s;
pub mod circle;
pub mod cm31;
pub mod column;
pub mod domain;
pub mod fft;
pub mod fri;
mod grind;
pub mod lookups;
pub mod m31;
#[cfg(not(target_arch = "wasm32"))]
pub mod poseidon252;
pub mod prefix_sum;
pub mod qm31;
pub mod quotients;
mod utils;
pub mod very_packed_m31;
#[cfg(not(target_arch = "wasm32"))]
pub mod poseidon_bls;
#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
pub struct SimdBackend;
impl Backend for SimdBackend {}
impl BackendForChannel<Blake2sMerkleChannel> for SimdBackend {}
#[cfg(not(target_arch = "wasm32"))]
impl BackendForChannel<Poseidon252MerkleChannel> for SimdBackend {}
#[cfg(not(target_arch = "wasm32"))]
impl BackendForChannel<PoseidonBLSMerkleChannel> for SimdBackend {}

View File

@ -0,0 +1,36 @@
use itertools::Itertools;
use starknet_ff::FieldElement as FieldElement252;
use super::SimdBackend;
use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::fields::m31::BaseField;
#[cfg(not(target_arch = "wasm32"))]
use crate::core::vcs::ops::MerkleHasher;
use crate::core::vcs::ops::MerkleOps;
use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleHasher;
impl ColumnOps<FieldElement252> for SimdBackend {
type Column = Vec<FieldElement252>;
fn bit_reverse_column(_column: &mut Self::Column) {
unimplemented!()
}
}
impl MerkleOps<Poseidon252MerkleHasher> for SimdBackend {
// TODO(ShaharS): replace with SIMD implementation.
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<FieldElement252>>,
columns: &[&Col<Self, BaseField>],
) -> Vec<FieldElement252> {
(0..(1 << log_size))
.map(|i| {
Poseidon252MerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column.at(i)).collect_vec(),
)
})
.collect()
}
}

View File

@ -0,0 +1,36 @@
use itertools::Itertools;
use ark_bls12_381::Fr as BlsFr;
use super::SimdBackend;
use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::fields::m31::BaseField;
#[cfg(not(target_arch = "wasm32"))]
use crate::core::vcs::ops::MerkleHasher;
use crate::core::vcs::ops::MerkleOps;
use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleHasher;
impl ColumnOps<BlsFr> for SimdBackend {
type Column = Vec<BlsFr>;
fn bit_reverse_column(_column: &mut Self::Column) {
unimplemented!()
}
}
impl MerkleOps<PoseidonBLSMerkleHasher> for SimdBackend {
// TODO(ShaharS): replace with SIMD implementation.
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<BlsFr>>,
columns: &[&Col<Self, BaseField>],
) -> Vec<BlsFr> {
(0..(1 << log_size))
.map(|i| {
PoseidonBLSMerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column.at(i)).collect_vec(),
)
})
.collect()
}
}

View File

@ -0,0 +1,188 @@
use std::iter::zip;
use std::ops::{AddAssign, Sub};
use itertools::{izip, Itertools};
use num_traits::Zero;
use crate::core::backend::simd::m31::{PackedBaseField, N_LANES};
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Col, Column};
use crate::core::fields::m31::BaseField;
use crate::core::utils::{
bit_reverse, circle_domain_order_to_coset_order, coset_order_to_circle_domain_order,
};
/// Performs a inclusive prefix sum on values in `Coset` order when provided
/// with evaluations in bit-reversed `CircleDomain` order.
///
/// Based on parallel Blelloch prefix sum:
/// <https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda>
pub fn inclusive_prefix_sum(
bit_rev_circle_domain_evals: Col<SimdBackend, BaseField>,
) -> Col<SimdBackend, BaseField> {
if bit_rev_circle_domain_evals.len() < N_LANES * 4 {
return inclusive_prefix_sum_slow(bit_rev_circle_domain_evals);
}
let mut res = bit_rev_circle_domain_evals;
let packed_len = res.data.len();
let (l_half, r_half) = res.data.split_at_mut(packed_len / 2);
// Up Sweep
// ========
// Handle the first two up sweep rounds manually.
// Required due different ordering of `CircleDomain` and `Coset`.
// Evaluations are provided in bit-reversed `CircleDomain` order.
for ([l0, l1], [r0, r1]) in izip!(l_half.array_chunks_mut(), r_half.array_chunks_mut().rev()) {
let (mut half_coset0_lo, half_coset1_hi_rev) = l0.deinterleave(*l1);
let half_coset1_hi = half_coset1_hi_rev.reverse();
let (mut half_coset0_hi, half_coset1_lo_rev) = r0.deinterleave(*r1);
let half_coset1_lo = half_coset1_lo_rev.reverse();
up_sweep_val(&mut half_coset0_lo, half_coset1_lo);
up_sweep_val(&mut half_coset0_hi, half_coset1_hi);
*l0 = half_coset0_lo;
*l1 = half_coset0_hi;
*r0 = half_coset1_lo;
*r1 = half_coset1_hi;
}
let half_coset0_sums: &mut [PackedBaseField] = l_half;
for i in 0..half_coset0_sums.len() / 2 {
let lo_index = i * 2;
let hi_index = half_coset0_sums.len() - 1 - i * 2;
let hi = half_coset0_sums[hi_index];
up_sweep_val(&mut half_coset0_sums[lo_index], hi)
}
// Handle remaining up sweep rounds.
let mut chunk_size = half_coset0_sums.len() / 2;
while chunk_size > 1 {
let (lows, highs) = half_coset0_sums.split_at_mut(chunk_size);
zip(lows.array_chunks_mut(), highs.array_chunks())
.for_each(|([lo, _], [hi, _])| up_sweep_val(lo, *hi));
chunk_size /= 2;
}
// Up sweep the last SIMD vector.
let mut first_vec = half_coset0_sums.first().unwrap().to_array();
let mut chunk_size = first_vec.len() / 2;
while chunk_size > 0 {
let (lows, highs) = first_vec.split_at_mut(chunk_size);
zip(lows, highs).for_each(|(lo, hi)| up_sweep_val(lo, *hi));
chunk_size /= 2;
}
// Down Sweep
// ==========
// Down sweep the last SIMD vector.
let mut chunk_size = 1;
while chunk_size < first_vec.len() {
let (lows, highs) = first_vec.split_at_mut(chunk_size);
zip(lows, highs).for_each(|(lo, hi)| down_sweep_val(lo, hi));
chunk_size *= 2;
}
// Re-insert the SIMD vector.
*half_coset0_sums.first_mut().unwrap() = first_vec.into();
// Handle remaining down sweep rounds (except first two).
let mut chunk_size = 2;
while chunk_size < half_coset0_sums.len() {
let (lows, highs) = half_coset0_sums.split_at_mut(chunk_size);
zip(lows.array_chunks_mut(), highs.array_chunks_mut())
.for_each(|([lo, _], [hi, _])| down_sweep_val(lo, hi));
chunk_size *= 2;
}
// Handle last two down sweep rounds manually.
// Required due different ordering of `CircleDomain` and `Coset`.
// Evaluations must be returned in bit-reversed `CircleDomain` order.
for i in 0..half_coset0_sums.len() / 2 {
let lo_index = i * 2;
let hi_index = half_coset0_sums.len() - 1 - i * 2;
let (mut lo, mut hi) = (half_coset0_sums[lo_index], half_coset0_sums[hi_index]);
down_sweep_val(&mut lo, &mut hi);
(half_coset0_sums[lo_index], half_coset0_sums[hi_index]) = (lo, hi);
}
for ([l0, l1], [r0, r1]) in izip!(l_half.array_chunks_mut(), r_half.array_chunks_mut().rev()) {
let mut half_coset0_lo = *l0;
let mut half_coset1_lo = *r0;
down_sweep_val(&mut half_coset0_lo, &mut half_coset1_lo);
let mut half_coset0_hi = *l1;
let mut half_coset1_hi = *r1;
down_sweep_val(&mut half_coset0_hi, &mut half_coset1_hi);
(*l0, *l1) = half_coset0_lo.interleave(half_coset1_hi.reverse());
(*r0, *r1) = half_coset0_hi.interleave(half_coset1_lo.reverse());
}
res
}
fn up_sweep_val<F: AddAssign + Copy>(lo: &mut F, hi: F) {
*lo += hi;
}
fn down_sweep_val<F: Sub<Output = F> + Copy>(lo: &mut F, hi: &mut F) {
(*lo, *hi) = (*lo - *hi, *lo)
}
fn inclusive_prefix_sum_slow(
bit_rev_circle_domain_evals: Col<SimdBackend, BaseField>,
) -> Col<SimdBackend, BaseField> {
// Obtain values in coset order.
let mut coset_order_eval = bit_rev_circle_domain_evals.into_cpu_vec();
bit_reverse(&mut coset_order_eval);
coset_order_eval = circle_domain_order_to_coset_order(&coset_order_eval);
let coset_order_prefix_sum = coset_order_eval
.into_iter()
.scan(BaseField::zero(), |acc, v| {
*acc += v;
Some(*acc)
})
.collect_vec();
let mut circle_domain_order_eval = coset_order_to_circle_domain_order(&coset_order_prefix_sum);
bit_reverse(&mut circle_domain_order_eval);
circle_domain_order_eval.into_iter().collect()
}
#[cfg(test)]
mod tests {
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use test_log::test;
use super::inclusive_prefix_sum;
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum_slow;
use crate::core::backend::Column;
#[test]
fn exclusive_prefix_sum_simd_with_log_size_3_works() {
const LOG_N: u32 = 3;
let mut rng = SmallRng::seed_from_u64(0);
let evals: BaseColumn = (0..1 << LOG_N).map(|_| rng.gen()).collect();
let expected = inclusive_prefix_sum_slow(evals.clone());
let res = inclusive_prefix_sum(evals);
assert_eq!(res.to_cpu(), expected.to_cpu());
}
#[test]
fn exclusive_prefix_sum_simd_with_log_size_6_works() {
const LOG_N: u32 = 6;
let mut rng = SmallRng::seed_from_u64(0);
let evals: BaseColumn = (0..1 << LOG_N).map(|_| rng.gen()).collect();
let expected = inclusive_prefix_sum_slow(evals.clone());
let res = inclusive_prefix_sum(evals);
assert_eq!(res.to_cpu(), expected.to_cpu());
}
#[test]
fn exclusive_prefix_sum_simd_with_log_size_8_works() {
const LOG_N: u32 = 8;
let mut rng = SmallRng::seed_from_u64(0);
let evals: BaseColumn = (0..1 << LOG_N).map(|_| rng.gen()).collect();
let expected = inclusive_prefix_sum_slow(evals.clone());
let res = inclusive_prefix_sum(evals);
assert_eq!(res.to_cpu(), expected.to_cpu());
}
}

View File

@ -0,0 +1,357 @@
use std::array;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use bytemuck::{Pod, Zeroable};
use num_traits::{One, Zero};
use rand::distributions::{Distribution, Standard};
use super::cm31::PackedCM31;
use super::m31::{PackedM31, N_LANES};
use crate::core::fields::qm31::QM31;
use crate::core::fields::FieldExpOps;
pub type PackedSecureField = PackedQM31;
/// SIMD implementation of [`QM31`].
#[derive(Copy, Clone, Debug)]
pub struct PackedQM31(pub [PackedCM31; 2]);
impl PackedQM31 {
/// Constructs a new instance with all vector elements set to `value`.
pub fn broadcast(value: QM31) -> Self {
Self([
PackedCM31::broadcast(value.0),
PackedCM31::broadcast(value.1),
])
}
/// Returns all `a` values such that each vector element is represented as `a + bu`.
pub fn a(&self) -> PackedCM31 {
self.0[0]
}
/// Returns all `b` values such that each vector element is represented as `a + bu`.
pub fn b(&self) -> PackedCM31 {
self.0[1]
}
pub fn to_array(&self) -> [QM31; N_LANES] {
let a = self.a().to_array();
let b = self.b().to_array();
array::from_fn(|i| QM31(a[i], b[i]))
}
pub fn from_array(values: [QM31; N_LANES]) -> Self {
let a = values.map(|v| v.0);
let b = values.map(|v| v.1);
Self([PackedCM31::from_array(a), PackedCM31::from_array(b)])
}
/// Interleaves two vectors.
pub fn interleave(self, other: Self) -> (Self, Self) {
let Self([a_evens, b_evens]) = self;
let Self([a_odds, b_odds]) = other;
let (a_lhs, a_rhs) = a_evens.interleave(a_odds);
let (b_lhs, b_rhs) = b_evens.interleave(b_odds);
(Self([a_lhs, b_lhs]), Self([a_rhs, b_rhs]))
}
/// Deinterleaves two vectors.
pub fn deinterleave(self, other: Self) -> (Self, Self) {
let Self([a_lhs, b_lhs]) = self;
let Self([a_rhs, b_rhs]) = other;
let (a_evens, a_odds) = a_lhs.deinterleave(a_rhs);
let (b_evens, b_odds) = b_lhs.deinterleave(b_rhs);
(Self([a_evens, b_evens]), Self([a_odds, b_odds]))
}
/// Sums all the elements in the vector.
pub fn pointwise_sum(self) -> QM31 {
self.to_array().into_iter().sum()
}
/// Doubles each element in the vector.
pub fn double(self) -> Self {
let Self([a, b]) = self;
Self([a.double(), b.double()])
}
/// Returns vectors `a, b, c, d` such that element `i` is represented as
/// `QM31(a_i, b_i, c_i, d_i)`.
pub fn into_packed_m31s(self) -> [PackedM31; 4] {
let Self([PackedCM31([a, b]), PackedCM31([c, d])]) = self;
[a, b, c, d]
}
/// Creates an instance from vectors `a, b, c, d` such that element `i`
/// is represented as `QM31(a_i, b_i, c_i, d_i)`.
pub fn from_packed_m31s([a, b, c, d]: [PackedM31; 4]) -> Self {
Self([PackedCM31([a, b]), PackedCM31([c, d])])
}
}
impl Add for PackedQM31 {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
Self([self.a() + rhs.a(), self.b() + rhs.b()])
}
}
impl Sub for PackedQM31 {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
Self([self.a() - rhs.a(), self.b() - rhs.b()])
}
}
impl Mul for PackedQM31 {
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
// Compute using Karatsuba.
// (a + ub) * (c + ud) =
// (ac + (2+i)bd) + (ad + bc)u =
// ac + 2bd + ibd + (ad + bc)u.
let ac = self.a() * rhs.a();
let bd = self.b() * rhs.b();
let bd_times_1_plus_i = PackedCM31([bd.a() - bd.b(), bd.a() + bd.b()]);
// Computes ac + bd.
let ac_p_bd = ac + bd;
// Computes ad + bc.
let ad_p_bc = (self.a() + self.b()) * (rhs.a() + rhs.b()) - ac_p_bd;
// ac + 2bd + ibd =
// ac + bd + bd + ibd
let l = PackedCM31([
ac_p_bd.a() + bd_times_1_plus_i.a(),
ac_p_bd.b() + bd_times_1_plus_i.b(),
]);
Self([l, ad_p_bc])
}
}
impl Zero for PackedQM31 {
fn zero() -> Self {
Self([PackedCM31::zero(), PackedCM31::zero()])
}
fn is_zero(&self) -> bool {
self.a().is_zero() && self.b().is_zero()
}
}
impl One for PackedQM31 {
fn one() -> Self {
Self([PackedCM31::one(), PackedCM31::zero()])
}
}
impl AddAssign for PackedQM31 {
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
impl MulAssign for PackedQM31 {
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
impl FieldExpOps for PackedQM31 {
fn inverse(&self) -> Self {
assert!(!self.is_zero(), "0 has no inverse");
// (a + bu)^-1 = (a - bu) / (a^2 - (2+i)b^2).
let b2 = self.b().square();
let ib2 = PackedCM31([-b2.b(), b2.a()]);
let denom = self.a().square() - (b2 + b2 + ib2);
let denom_inverse = denom.inverse();
Self([self.a() * denom_inverse, -self.b() * denom_inverse])
}
}
impl Add<PackedM31> for PackedQM31 {
type Output = Self;
fn add(self, rhs: PackedM31) -> Self::Output {
Self([self.a() + rhs, self.b()])
}
}
impl Mul<PackedM31> for PackedQM31 {
type Output = Self;
fn mul(self, rhs: PackedM31) -> Self::Output {
let Self([a, b]) = self;
Self([a * rhs, b * rhs])
}
}
impl Mul<PackedCM31> for PackedQM31 {
type Output = Self;
fn mul(self, rhs: PackedCM31) -> Self::Output {
let Self([a, b]) = self;
Self([a * rhs, b * rhs])
}
}
impl Sub<PackedM31> for PackedQM31 {
type Output = Self;
fn sub(self, rhs: PackedM31) -> Self::Output {
let Self([a, b]) = self;
Self([a - rhs, b])
}
}
impl Add<QM31> for PackedQM31 {
type Output = Self;
fn add(self, rhs: QM31) -> Self::Output {
self + PackedQM31::broadcast(rhs)
}
}
impl Sub<QM31> for PackedQM31 {
type Output = Self;
fn sub(self, rhs: QM31) -> Self::Output {
self - PackedQM31::broadcast(rhs)
}
}
impl Mul<QM31> for PackedQM31 {
type Output = Self;
fn mul(self, rhs: QM31) -> Self::Output {
self * PackedQM31::broadcast(rhs)
}
}
impl SubAssign for PackedQM31 {
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
}
}
unsafe impl Pod for PackedQM31 {}
unsafe impl Zeroable for PackedQM31 {
fn zeroed() -> Self {
unsafe { core::mem::zeroed() }
}
}
impl Sum for PackedQM31 {
fn sum<I>(mut iter: I) -> Self
where
I: Iterator<Item = Self>,
{
let first = iter.next().unwrap_or_else(Self::zero);
iter.fold(first, |a, b| a + b)
}
}
impl<'a> Sum<&'a Self> for PackedQM31 {
fn sum<I>(iter: I) -> Self
where
I: Iterator<Item = &'a Self>,
{
iter.copied().sum()
}
}
impl Neg for PackedQM31 {
type Output = Self;
fn neg(self) -> Self::Output {
let Self([a, b]) = self;
Self([-a, -b])
}
}
impl Distribution<PackedQM31> for Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> PackedQM31 {
PackedQM31::from_array(rng.gen())
}
}
impl From<PackedM31> for PackedQM31 {
fn from(value: PackedM31) -> Self {
PackedQM31::from_packed_m31s([
value,
PackedM31::zero(),
PackedM31::zero(),
PackedM31::zero(),
])
}
}
impl From<QM31> for PackedQM31 {
fn from(value: QM31) -> Self {
PackedQM31::broadcast(value)
}
}
#[cfg(test)]
mod tests {
use std::array;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use crate::core::backend::simd::qm31::PackedQM31;
#[test]
fn addition_works() {
let mut rng = SmallRng::seed_from_u64(0);
let lhs = rng.gen();
let rhs = rng.gen();
let packed_lhs = PackedQM31::from_array(lhs);
let packed_rhs = PackedQM31::from_array(rhs);
let res = packed_lhs + packed_rhs;
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i]));
}
#[test]
fn subtraction_works() {
let mut rng = SmallRng::seed_from_u64(0);
let lhs = rng.gen();
let rhs = rng.gen();
let packed_lhs = PackedQM31::from_array(lhs);
let packed_rhs = PackedQM31::from_array(rhs);
let res = packed_lhs - packed_rhs;
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i]));
}
#[test]
fn multiplication_works() {
let mut rng = SmallRng::seed_from_u64(0);
let lhs = rng.gen();
let rhs = rng.gen();
let packed_lhs = PackedQM31::from_array(lhs);
let packed_rhs = PackedQM31::from_array(rhs);
let res = packed_lhs * packed_rhs;
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i]));
}
#[test]
fn negation_works() {
let mut rng = SmallRng::seed_from_u64(0);
let values = rng.gen();
let packed_values = PackedQM31::from_array(values);
let res = -packed_values;
assert_eq!(res.to_array(), values.map(|v| -v));
}
}

View File

@ -0,0 +1,314 @@
use itertools::{izip, zip_eq, Itertools};
use num_traits::Zero;
use tracing::{span, Level};
use super::cm31::PackedCM31;
use super::column::CM31Column;
use super::domain::CircleDomainBitRevIterator;
use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
use super::qm31::PackedSecureField;
use super::SimdBackend;
use crate::core::backend::cpu::quotients::{batch_random_coeffs, column_line_coeffs};
use crate::core::backend::Column;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE};
use crate::core::fields::FieldExpOps;
use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps};
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, PolyOps, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::bit_reverse;
pub struct QuotientConstants {
pub line_coeffs: Vec<Vec<(SecureField, SecureField, SecureField)>>,
pub batch_random_coeffs: Vec<SecureField>,
pub denominator_inverses: Vec<CM31Column>,
}
impl QuotientOps for SimdBackend {
fn accumulate_quotients(
domain: CircleDomain,
columns: &[&CircleEvaluation<Self, BaseField, BitReversedOrder>],
random_coeff: SecureField,
sample_batches: &[ColumnSampleBatch],
log_blowup_factor: u32,
) -> SecureEvaluation<Self, BitReversedOrder> {
// Split the domain into a subdomain and a shift coset.
// TODO(spapini): Move to the caller when Columns support slices.
let (subdomain, mut subdomain_shifts) = domain.split(log_blowup_factor);
// Bit reverse the shifts.
// Since we traverse the domain in bit-reversed order, we need bit-reverse the shifts.
// To see why, consider the index of a point in the natural order of the domain
// (least to most):
// b0 b1 b2 b3 b4 b5
// b0 adds G, b1 adds 2G, etc.. (b5 is special and flips the sign of the point).
// Splitting the domain to 4 parts yields:
// subdomain: b2 b3 b4 b5, shifts: b0 b1.
// b2 b3 b4 b5 is indeed a circle domain, with a bigger jump.
// Traversing the domain in bit-reversed order, after we finish with b5, b4, b3, b2,
// we need to change b1 and then b0. This is the bit reverse of the shift b0 b1.
bit_reverse(&mut subdomain_shifts);
let (span, mut extended_eval, subeval_polys) = accumulate_quotients_on_subdomain(
subdomain,
sample_batches,
random_coeff,
columns,
domain,
);
// Extend the evaluation to the full domain.
// TODO(spapini): Try to optimize out all these copies.
for (ci, &c) in subdomain_shifts.iter().enumerate() {
let subdomain = subdomain.shift(c);
let twiddles = SimdBackend::precompute_twiddles(subdomain.half_coset);
#[allow(clippy::needless_range_loop)]
for i in 0..SECURE_EXTENSION_DEGREE {
// Sanity check.
let eval = subeval_polys[i].evaluate_with_twiddles(subdomain, &twiddles);
extended_eval.columns[i].data[(ci * eval.data.len())..((ci + 1) * eval.data.len())]
.copy_from_slice(&eval.data);
}
}
span.exit();
SecureEvaluation::new(domain, extended_eval)
}
}
fn accumulate_quotients_on_subdomain(
subdomain: CircleDomain,
sample_batches: &[ColumnSampleBatch],
random_coeff: SecureField,
columns: &[&CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>],
domain: CircleDomain,
) -> (
span::EnteredSpan,
SecureColumnByCoords<SimdBackend>,
[crate::core::poly::circle::CirclePoly<SimdBackend>; 4],
) {
assert!(subdomain.log_size() >= LOG_N_LANES + 2);
let mut values =
unsafe { SecureColumnByCoords::<SimdBackend>::uninitialized(subdomain.size()) };
let quotient_constants = quotient_constants(sample_batches, random_coeff, subdomain);
let span = span!(Level::INFO, "Quotient accumulation").entered();
for (quad_row, points) in CircleDomainBitRevIterator::new(subdomain)
.array_chunks::<4>()
.enumerate()
{
// TODO(spapini): Use optimized domain iteration.
let (y01, _) = points[0].y.deinterleave(points[1].y);
let (y23, _) = points[2].y.deinterleave(points[3].y);
let (spaced_ys, _) = y01.deinterleave(y23);
let row_accumulator = accumulate_row_quotients(
sample_batches,
columns,
&quotient_constants,
quad_row,
spaced_ys,
);
#[allow(clippy::needless_range_loop)]
for i in 0..4 {
unsafe { values.set_packed((quad_row << 2) + i, row_accumulator[i]) };
}
}
span.exit();
let span = span!(Level::INFO, "Quotient extension").entered();
// Extend the evaluation to the full domain.
let extended_eval =
unsafe { SecureColumnByCoords::<SimdBackend>::uninitialized(domain.size()) };
let mut i = 0;
let values = values.columns;
let twiddles = SimdBackend::precompute_twiddles(subdomain.half_coset);
let subeval_polys = values.map(|c| {
i += 1;
CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(subdomain, c)
.interpolate_with_twiddles(&twiddles)
});
(span, extended_eval, subeval_polys)
}
/// Accumulates the quotients for 4 * N_LANES rows at a time.
/// spaced_ys - y values for N_LANES points in the domain, in jumps of 4.
pub fn accumulate_row_quotients(
sample_batches: &[ColumnSampleBatch],
columns: &[&CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>],
quotient_constants: &QuotientConstants,
quad_row: usize,
spaced_ys: PackedBaseField,
) -> [PackedSecureField; 4] {
let mut row_accumulator = [PackedSecureField::zero(); 4];
for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!(
sample_batches,
&quotient_constants.line_coeffs,
&quotient_constants.batch_random_coeffs,
&quotient_constants.denominator_inverses
) {
let mut numerator = [PackedSecureField::zero(); 4];
for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs)
{
let column = &columns[*column_index];
let cvalues: [_; 4] = std::array::from_fn(|i| {
PackedSecureField::broadcast(*c) * column.data[(quad_row << 2) + i]
});
// The numerator is the line equation:
// c * value - a * point.y - b;
// Note that a, b, c were already multilpied by random_coeff^i.
// See [column_line_coeffs()] for more details.
// This is why we only add here.
// 4 consecutive point in the domain in bit reversed order are:
// P, -P, P + H, -P + H.
// H being the half point (-1,0). The y values for these are
// P.y, -P.y, -P.y, P.y.
// We use this fact to save multiplications.
// spaced_ys are the y value in jumps of 4:
// P0.y, P1.y, P2.y, ...
let spaced_ay = PackedSecureField::broadcast(*a) * spaced_ys;
// t0:t1 = a*P0.y, -a*P0.y, a*P1.y, -a*P1.y, ...
let (t0, t1) = spaced_ay.interleave(-spaced_ay);
// t2:t3:t4:t5 = a*P0.y, -a*P0.y, -a*P0.y, a*P0.y, a*P1.y, -a*P1.y, ...
let (t2, t3) = t0.interleave(-t0);
let (t4, t5) = t1.interleave(-t1);
let ay = [t2, t3, t4, t5];
for i in 0..4 {
numerator[i] += cvalues[i] - ay[i] - PackedSecureField::broadcast(*b);
}
}
for i in 0..4 {
row_accumulator[i] = row_accumulator[i] * PackedSecureField::broadcast(*batch_coeff)
+ numerator[i] * denominator_inverses.data[(quad_row << 2) + i];
}
}
row_accumulator
}
fn denominator_inverses(
sample_batches: &[ColumnSampleBatch],
domain: CircleDomain,
) -> Vec<CM31Column> {
// We want a P to be on a line that passes through a point Pr + uPi in QM31^2, and its conjugate
// Pr - uPi. Thus, Pr - P is parallel to Pi. Or, (Pr - P).x * Pi.y - (Pr - P).y * Pi.x = 0.
let flat_denominators: CM31Column = sample_batches
.iter()
.flat_map(|sample_batch| {
// Extract Pr, Pi.
let prx = PackedCM31::broadcast(sample_batch.point.x.0);
let pry = PackedCM31::broadcast(sample_batch.point.y.0);
let pix = PackedCM31::broadcast(sample_batch.point.x.1);
let piy = PackedCM31::broadcast(sample_batch.point.y.1);
// Line equation through pr +-u pi.
// (p-pr)*
CircleDomainBitRevIterator::new(domain)
.map(|points| (prx - points.x) * piy - (pry - points.y) * pix)
.collect_vec()
})
.collect();
let mut flat_denominator_inverses =
unsafe { CM31Column::uninitialized(flat_denominators.len()) };
FieldExpOps::batch_inverse(
&flat_denominators.data,
&mut flat_denominator_inverses.data[..],
);
flat_denominator_inverses
.data
.chunks(domain.size() / N_LANES)
.map(|denominator_inverses| denominator_inverses.iter().copied().collect())
.collect()
}
fn quotient_constants(
sample_batches: &[ColumnSampleBatch],
random_coeff: SecureField,
domain: CircleDomain,
) -> QuotientConstants {
let _span = span!(Level::INFO, "Quotient constants").entered();
let line_coeffs = column_line_coeffs(sample_batches, random_coeff);
let batch_random_coeffs = batch_random_coeffs(sample_batches, random_coeff);
let denominator_inverses = denominator_inverses(sample_batches, domain);
QuotientConstants {
line_coeffs,
batch_random_coeffs,
denominator_inverses,
}
}
#[cfg(test)]
mod tests {
use itertools::Itertools;
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, CpuBackend};
use crate::core::circle::SECURE_FIELD_CIRCLE_GEN;
use crate::core::fields::m31::BaseField;
use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::qm31;
#[test]
fn test_accumulate_quotients() {
const LOG_SIZE: u32 = 8;
const LOG_BLOWUP_FACTOR: u32 = 1;
let small_domain = CanonicCoset::new(LOG_SIZE).circle_domain();
let domain = CanonicCoset::new(LOG_SIZE + LOG_BLOWUP_FACTOR).circle_domain();
let e0: BaseColumn = (0..small_domain.size()).map(BaseField::from).collect();
let e1: BaseColumn = (0..small_domain.size())
.map(|i| BaseField::from(2 * i))
.collect();
let polys = vec![
CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(small_domain, e0)
.interpolate(),
CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(small_domain, e1)
.interpolate(),
];
let columns = vec![polys[0].evaluate(domain), polys[1].evaluate(domain)];
let random_coeff = qm31!(1, 2, 3, 4);
let a = polys[0].eval_at_point(SECURE_FIELD_CIRCLE_GEN);
let b = polys[1].eval_at_point(SECURE_FIELD_CIRCLE_GEN);
let samples = vec![ColumnSampleBatch {
point: SECURE_FIELD_CIRCLE_GEN,
columns_and_values: vec![(0, a), (1, b)],
}];
let cpu_columns = columns
.iter()
.map(|c| {
CircleEvaluation::<CpuBackend, _, BitReversedOrder>::new(
c.domain,
c.values.to_cpu(),
)
})
.collect::<Vec<_>>();
let cpu_result = CpuBackend::accumulate_quotients(
domain,
&cpu_columns.iter().collect_vec(),
random_coeff,
&samples,
LOG_BLOWUP_FACTOR,
)
.values
.to_vec();
let res = SimdBackend::accumulate_quotients(
domain,
&columns.iter().collect_vec(),
random_coeff,
&samples,
LOG_BLOWUP_FACTOR,
)
.values
.to_vec();
assert_eq!(res, cpu_result);
}
}

View File

@ -0,0 +1,52 @@
use std::simd::Swizzle;
/// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors.
pub struct InterleaveEvens;
impl<const N: usize> Swizzle<N> for InterleaveEvens {
const INDEX: [usize; N] = parity_interleave(false);
}
/// Used with [`Swizzle::concat_swizzle`] to interleave the odd values of two vectors.
pub struct InterleaveOdds;
impl<const N: usize> Swizzle<N> for InterleaveOdds {
const INDEX: [usize; N] = parity_interleave(true);
}
const fn parity_interleave<const N: usize>(odd: bool) -> [usize; N] {
let mut res = [0; N];
let mut i = 0;
while i < N {
res[i] = (i % 2) * N + (i / 2) * 2 + if odd { 1 } else { 0 };
i += 1;
}
res
}
#[cfg(test)]
mod tests {
use std::simd::{u32x4, Swizzle};
use super::{InterleaveEvens, InterleaveOdds};
#[test]
fn interleave_evens() {
let lo = u32x4::from_array([0, 1, 2, 3]);
let hi = u32x4::from_array([4, 5, 6, 7]);
let res = InterleaveEvens::concat_swizzle(lo, hi);
assert_eq!(res, u32x4::from_array([0, 4, 2, 6]));
}
#[test]
fn interleave_odds() {
let lo = u32x4::from_array([0, 1, 2, 3]);
let hi = u32x4::from_array([4, 5, 6, 7]);
let res = InterleaveOdds::concat_swizzle(lo, hi);
assert_eq!(res, u32x4::from_array([1, 5, 3, 7]));
}
}

View File

@ -0,0 +1,222 @@
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};
use bytemuck::{Pod, Zeroable};
use num_traits::{One, Zero};
use super::cm31::PackedCM31;
use super::m31::{PackedM31, N_LANES};
use super::qm31::PackedQM31;
use crate::core::fields::cm31::CM31;
use crate::core::fields::m31::{pow2147483645, M31};
use crate::core::fields::qm31::QM31;
use crate::core::fields::FieldExpOps;
pub const LOG_N_VERY_PACKED_ELEMS: u32 = 1;
pub const N_VERY_PACKED_ELEMS: usize = 1 << LOG_N_VERY_PACKED_ELEMS;
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct Vectorized<A, const N: usize>(pub [A; N]);
impl<A, const N: usize> Vectorized<A, N> {
pub fn from_fn<F>(cb: F) -> Self
where
F: FnMut(usize) -> A,
{
Vectorized(std::array::from_fn(cb))
}
}
unsafe impl<A, const N: usize> Zeroable for Vectorized<A, N> {
fn zeroed() -> Self {
unsafe { core::mem::zeroed() }
}
}
unsafe impl<A: Pod, const N: usize> Pod for Vectorized<A, N> {}
pub type VeryPackedM31 = Vectorized<PackedM31, N_VERY_PACKED_ELEMS>;
pub type VeryPackedCM31 = Vectorized<PackedCM31, N_VERY_PACKED_ELEMS>;
pub type VeryPackedQM31 = Vectorized<PackedQM31, N_VERY_PACKED_ELEMS>;
pub type VeryPackedBaseField = VeryPackedM31;
pub type VeryPackedSecureField = VeryPackedQM31;
impl VeryPackedM31 {
pub fn broadcast(value: M31) -> Self {
Self::from_fn(|_| PackedM31::broadcast(value))
}
pub fn from_array(values: [M31; N_LANES * N_VERY_PACKED_ELEMS]) -> VeryPackedM31 {
Self::from_fn(|i| {
let start = i * N_LANES;
let end = start + N_LANES;
PackedM31::from_array(values[start..end].try_into().unwrap())
})
}
pub fn to_array(&self) -> [M31; N_LANES * N_VERY_PACKED_ELEMS] {
// Safety: We are transmuting &[A; N_VERY_PACKED_ELEMS] into &[i32; N_LANES *
// N_VERY_PACKED_ELEMS] because we know that A contains [i32; N_LANES] and the
// memory layout is contiguous.
unsafe {
std::slice::from_raw_parts(self.0.as_ptr() as *const M31, N_LANES * N_VERY_PACKED_ELEMS)
.try_into()
.unwrap()
}
}
}
impl VeryPackedCM31 {
pub fn broadcast(value: CM31) -> Self {
Self::from_fn(|_| PackedCM31::broadcast(value))
}
}
impl VeryPackedQM31 {
pub fn broadcast(value: QM31) -> Self {
Self::from_fn(|_| PackedQM31::broadcast(value))
}
pub fn from_very_packed_m31s([a, b, c, d]: [VeryPackedM31; 4]) -> Self {
Self::from_fn(|i| PackedQM31::from_packed_m31s([a.0[i], b.0[i], c.0[i], d.0[i]]))
}
}
impl From<M31> for VeryPackedM31 {
fn from(v: M31) -> Self {
Self::broadcast(v)
}
}
impl From<VeryPackedM31> for VeryPackedQM31 {
fn from(value: VeryPackedM31) -> Self {
VeryPackedQM31::from_very_packed_m31s([
value,
VeryPackedM31::zero(),
VeryPackedM31::zero(),
VeryPackedM31::zero(),
])
}
}
impl From<QM31> for VeryPackedQM31 {
fn from(value: QM31) -> Self {
VeryPackedQM31::broadcast(value)
}
}
trait Scalar {}
impl Scalar for M31 {}
impl Scalar for CM31 {}
impl Scalar for QM31 {}
impl Scalar for PackedM31 {}
impl Scalar for PackedCM31 {}
impl Scalar for PackedQM31 {}
impl<A: Add<B> + Copy, B: Copy, const N: usize> Add<Vectorized<B, N>> for Vectorized<A, N> {
type Output = Vectorized<A::Output, N>;
fn add(self, other: Vectorized<B, N>) -> Self::Output {
Vectorized::from_fn(|i| self.0[i] + other.0[i])
}
}
impl<A: Add<B> + Copy, B: Scalar + Copy, const N: usize> Add<B> for Vectorized<A, N> {
type Output = Vectorized<A::Output, N>;
fn add(self, other: B) -> Self::Output {
Vectorized::from_fn(|i| self.0[i] + other)
}
}
impl<A: Sub<B> + Copy, B: Copy, const N: usize> Sub<Vectorized<B, N>> for Vectorized<A, N> {
type Output = Vectorized<A::Output, N>;
fn sub(self, other: Vectorized<B, N>) -> Self::Output {
Vectorized::from_fn(|i| self.0[i] - other.0[i])
}
}
impl<A: Sub<B> + Copy, B: Scalar + Copy, const N: usize> Sub<B> for Vectorized<A, N> {
type Output = Vectorized<A::Output, N>;
fn sub(self, other: B) -> Self::Output {
Vectorized::from_fn(|i| self.0[i] - other)
}
}
impl<A: Mul<B> + Copy, B: Copy, const N: usize> Mul<Vectorized<B, N>> for Vectorized<A, N> {
type Output = Vectorized<A::Output, N>;
fn mul(self, other: Vectorized<B, N>) -> Self::Output {
Vectorized::from_fn(|i| self.0[i] * other.0[i])
}
}
impl<A: Mul<B> + Copy, B: Scalar + Copy, const N: usize> Mul<B> for Vectorized<A, N> {
type Output = Vectorized<A::Output, N>;
fn mul(self, other: B) -> Self::Output {
Vectorized::from_fn(|i| self.0[i] * other)
}
}
impl<A: AddAssign<B> + Copy, B: Copy, const N: usize> AddAssign<Vectorized<B, N>>
for Vectorized<A, N>
{
fn add_assign(&mut self, other: Vectorized<B, N>) {
for i in 0..N {
self.0[i] += other.0[i];
}
}
}
impl<A: AddAssign<B> + Copy, B: Scalar + Copy, const N: usize> AddAssign<B> for Vectorized<A, N> {
fn add_assign(&mut self, other: B) {
for i in 0..N {
self.0[i] += other;
}
}
}
impl<A: MulAssign<B> + Copy, B: Copy, const N: usize> MulAssign<Vectorized<B, N>>
for Vectorized<A, N>
{
fn mul_assign(&mut self, other: Vectorized<B, N>) {
for i in 0..N {
self.0[i] *= other.0[i];
}
}
}
impl<A: Neg + Copy, const N: usize> Neg for Vectorized<A, N> {
type Output = Vectorized<A::Output, N>;
#[inline(always)]
fn neg(self) -> Self::Output {
Vectorized::from_fn(|i| self.0[i].neg())
}
}
impl<A: Zero + Copy, const N: usize> Zero for Vectorized<A, N> {
fn zero() -> Self {
Vectorized::from_fn(|_| A::zero())
}
fn is_zero(&self) -> bool {
self.0.iter().all(A::is_zero)
}
}
impl<A: One + Copy, const N: usize> One for Vectorized<A, N> {
fn one() -> Self {
Vectorized::from_fn(|_| A::one())
}
}
impl<A: FieldExpOps + Zero, const N: usize> FieldExpOps for Vectorized<A, N> {
fn inverse(&self) -> Self {
Vectorized::from_fn(|i| {
assert!(!self.0[i].is_zero(), "0 has no inverse");
pow2147483645(self.0[i])
})
}
}

View File

@ -0,0 +1,186 @@
use std::iter;
use super::{Channel, ChannelTime};
use crate::core::fields::m31::{BaseField, N_BYTES_FELT, P};
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::fields::IntoSlice;
use crate::core::vcs::blake2_hash::{Blake2sHash, Blake2sHasher};
use crate::core::vcs::blake2s_ref::compress;
pub const BLAKE_BYTES_PER_HASH: usize = 32;
pub const FELTS_PER_HASH: usize = 8;
/// A channel that can be used to draw random elements from a [Blake2sHash] digest.
#[derive(Default, Clone)]
pub struct Blake2sChannel {
digest: Blake2sHash,
pub channel_time: ChannelTime,
}
impl Blake2sChannel {
pub fn digest(&self) -> Blake2sHash {
self.digest
}
pub fn update_digest(&mut self, new_digest: Blake2sHash) {
self.digest = new_digest;
self.channel_time.inc_challenges();
}
/// Generates a uniform random vector of BaseField elements.
fn draw_base_felts(&mut self) -> [BaseField; FELTS_PER_HASH] {
// Repeats hashing with an increasing counter until getting a good result.
// Retry probability for each round is ~ 2^(-28).
loop {
let u32s: [u32; FELTS_PER_HASH] = self
.draw_random_bytes()
.chunks_exact(N_BYTES_FELT) // 4 bytes per u32.
.map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap()))
.collect::<Vec<_>>()
.try_into()
.unwrap();
// Retry if not all the u32 are in the range [0, 2P).
if u32s.iter().all(|x| *x < 2 * P) {
return u32s
.into_iter()
.map(|x| BaseField::reduce(x as u64))
.collect::<Vec<_>>()
.try_into()
.unwrap();
}
}
}
}
impl Channel for Blake2sChannel {
const BYTES_PER_HASH: usize = BLAKE_BYTES_PER_HASH;
fn trailing_zeros(&self) -> u32 {
u128::from_le_bytes(std::array::from_fn(|i| self.digest.0[i])).trailing_zeros()
}
fn mix_felts(&mut self, felts: &[SecureField]) {
let mut hasher = Blake2sHasher::new();
hasher.update(self.digest.as_ref());
hasher.update(IntoSlice::<u8>::into_slice(felts));
self.update_digest(hasher.finalize());
}
fn mix_nonce(&mut self, nonce: u64) {
let digest: [u32; 8] = unsafe { std::mem::transmute(self.digest) };
let mut msg = [0; 16];
msg[0] = nonce as u32;
msg[1] = (nonce >> 32) as u32;
let res = compress(std::array::from_fn(|i| digest[i]), msg, 0, 0, 0, 0);
self.update_digest(unsafe { std::mem::transmute(res) });
}
fn draw_felt(&mut self) -> SecureField {
let felts: [BaseField; FELTS_PER_HASH] = self.draw_base_felts();
SecureField::from_m31_array(felts[..SECURE_EXTENSION_DEGREE].try_into().unwrap())
}
fn draw_felts(&mut self, n_felts: usize) -> Vec<SecureField> {
let mut felts = iter::from_fn(|| Some(self.draw_base_felts())).flatten();
let secure_felts = iter::from_fn(|| {
Some(SecureField::from_m31_array([
felts.next()?,
felts.next()?,
felts.next()?,
felts.next()?,
]))
});
secure_felts.take(n_felts).collect()
}
fn draw_random_bytes(&mut self) -> Vec<u8> {
let mut hash_input = self.digest.as_ref().to_vec();
// Pad the counter to 32 bytes.
let mut padded_counter = [0; BLAKE_BYTES_PER_HASH];
let counter_bytes = self.channel_time.n_sent.to_le_bytes();
padded_counter[0..counter_bytes.len()].copy_from_slice(&counter_bytes);
hash_input.extend_from_slice(&padded_counter);
// TODO(spapini): Are we worried about this drawing hash colliding with mix_digest?
self.channel_time.inc_sent();
Blake2sHasher::hash(&hash_input).into()
}
}
#[cfg(test)]
mod tests {
use std::collections::BTreeSet;
use crate::core::channel::blake2s::Blake2sChannel;
use crate::core::channel::Channel;
use crate::core::fields::qm31::SecureField;
use crate::m31;
#[test]
fn test_channel_time() {
let mut channel = Blake2sChannel::default();
assert_eq!(channel.channel_time.n_challenges, 0);
assert_eq!(channel.channel_time.n_sent, 0);
channel.draw_random_bytes();
assert_eq!(channel.channel_time.n_challenges, 0);
assert_eq!(channel.channel_time.n_sent, 1);
channel.draw_felts(9);
assert_eq!(channel.channel_time.n_challenges, 0);
assert_eq!(channel.channel_time.n_sent, 6);
}
#[test]
fn test_draw_random_bytes() {
let mut channel = Blake2sChannel::default();
let first_random_bytes = channel.draw_random_bytes();
// Assert that next random bytes are different.
assert_ne!(first_random_bytes, channel.draw_random_bytes());
}
#[test]
pub fn test_draw_felt() {
let mut channel = Blake2sChannel::default();
let first_random_felt = channel.draw_felt();
// Assert that next random felt is different.
assert_ne!(first_random_felt, channel.draw_felt());
}
#[test]
pub fn test_draw_felts() {
let mut channel = Blake2sChannel::default();
let mut random_felts = channel.draw_felts(5);
random_felts.extend(channel.draw_felts(4));
// Assert that all the random felts are unique.
assert_eq!(
random_felts.len(),
random_felts.iter().collect::<BTreeSet<_>>().len()
);
}
#[test]
pub fn test_mix_felts() {
let mut channel = Blake2sChannel::default();
let initial_digest = channel.digest;
let felts: Vec<SecureField> = (0..2)
.map(|i| SecureField::from(m31!(i + 1923782)))
.collect();
channel.mix_felts(felts.as_slice());
assert_ne!(initial_digest, channel.digest);
}
}

View File

@ -0,0 +1,57 @@
use super::fields::qm31::SecureField;
use super::vcs::ops::MerkleHasher;
#[cfg(not(target_arch = "wasm32"))]
mod poseidon252;
#[cfg(not(target_arch = "wasm32"))]
pub use poseidon252::Poseidon252Channel;
mod blake2s;
pub use blake2s::Blake2sChannel;
#[cfg(not(target_arch = "wasm32"))]
mod poseidon_bls;
#[cfg(not(target_arch = "wasm32"))]
pub use poseidon_bls::PoseidonBLSChannel;
pub const EXTENSION_FELTS_PER_HASH: usize = 2;
#[derive(Clone, Default)]
pub struct ChannelTime {
pub n_challenges: usize,
n_sent: usize,
}
impl ChannelTime {
fn inc_sent(&mut self) {
self.n_sent += 1;
}
fn inc_challenges(&mut self) {
self.n_challenges += 1;
self.n_sent = 0;
}
}
pub trait Channel: Default + Clone {
const BYTES_PER_HASH: usize;
fn trailing_zeros(&self) -> u32;
// Mix functions.
fn mix_felts(&mut self, felts: &[SecureField]);
fn mix_nonce(&mut self, nonce: u64);
// Draw functions.
fn draw_felt(&mut self) -> SecureField;
/// Generates a uniform random vector of SecureField elements.
fn draw_felts(&mut self, n_felts: usize) -> Vec<SecureField>;
/// Returns a vector of random bytes of length `BYTES_PER_HASH`.
fn draw_random_bytes(&mut self) -> Vec<u8>;
}
pub trait MerkleChannel: Default {
type C: Channel;
type H: MerkleHasher;
fn mix_root(channel: &mut Self::C, root: <Self::H as MerkleHasher>::Hash);
}

View File

@ -0,0 +1,190 @@
use std::iter;
use starknet_crypto::{poseidon_hash, poseidon_hash_many};
use starknet_ff::FieldElement as FieldElement252;
use super::{Channel, ChannelTime};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
pub const BYTES_PER_FELT252: usize = 31;
pub const FELTS_PER_HASH: usize = 8;
/// A channel that can be used to draw random elements from a Poseidon252 hash.
#[derive(Clone, Default)]
pub struct Poseidon252Channel {
digest: FieldElement252,
pub channel_time: ChannelTime,
}
impl Poseidon252Channel {
pub fn digest(&self) -> FieldElement252 {
self.digest
}
pub fn update_digest(&mut self, new_digest: FieldElement252) {
self.digest = new_digest;
self.channel_time.inc_challenges();
}
fn draw_felt252(&mut self) -> FieldElement252 {
let res = poseidon_hash(self.digest, self.channel_time.n_sent.into());
self.channel_time.inc_sent();
res
}
// TODO(spapini): Understand if we really need uniformity here.
/// Generates a close-to uniform random vector of BaseField elements.
fn draw_base_felts(&mut self) -> [BaseField; 8] {
let shift = (1u64 << 31).into();
let mut cur = self.draw_felt252();
let u32s: [u32; 8] = std::array::from_fn(|_| {
let next = cur.floor_div(shift);
let res = cur - next * shift;
cur = next;
res.try_into().unwrap()
});
u32s.into_iter()
.map(|x| BaseField::reduce(x as u64))
.collect::<Vec<_>>()
.try_into()
.unwrap()
}
}
impl Channel for Poseidon252Channel {
const BYTES_PER_HASH: usize = BYTES_PER_FELT252;
fn trailing_zeros(&self) -> u32 {
let bytes = self.digest.to_bytes_be();
u128::from_le_bytes(std::array::from_fn(|i| bytes[i])).trailing_zeros()
}
// TODO(spapini): Optimize.
fn mix_felts(&mut self, felts: &[SecureField]) {
let shift = (1u64 << 31).into();
let mut res = Vec::with_capacity(felts.len() / 2 + 2);
res.push(self.digest);
for chunk in felts.chunks(2) {
res.push(
chunk
.iter()
.flat_map(|x| x.to_m31_array())
.fold(FieldElement252::default(), |cur, y| {
cur * shift + y.0.into()
}),
);
}
// TODO(spapini): do we need length padding?
self.update_digest(poseidon_hash_many(&res));
}
fn mix_nonce(&mut self, nonce: u64) {
self.update_digest(poseidon_hash(self.digest, nonce.into()));
}
fn draw_felt(&mut self) -> SecureField {
let felts: [BaseField; FELTS_PER_HASH] = self.draw_base_felts();
SecureField::from_m31_array(felts[..SECURE_EXTENSION_DEGREE].try_into().unwrap())
}
fn draw_felts(&mut self, n_felts: usize) -> Vec<SecureField> {
let mut felts = iter::from_fn(|| Some(self.draw_base_felts())).flatten();
let secure_felts = iter::from_fn(|| {
Some(SecureField::from_m31_array([
felts.next()?,
felts.next()?,
felts.next()?,
felts.next()?,
]))
});
secure_felts.take(n_felts).collect()
}
fn draw_random_bytes(&mut self) -> Vec<u8> {
let shift = (1u64 << 8).into();
let mut cur = self.draw_felt252();
let bytes: [u8; 31] = std::array::from_fn(|_| {
let next = cur.floor_div(shift);
let res = cur - next * shift;
cur = next;
res.try_into().unwrap()
});
bytes.to_vec()
}
}
#[cfg(test)]
mod tests {
use std::collections::BTreeSet;
use crate::core::channel::poseidon252::Poseidon252Channel;
use crate::core::channel::Channel;
use crate::core::fields::qm31::SecureField;
use crate::m31;
#[test]
fn test_channel_time() {
let mut channel = Poseidon252Channel::default();
assert_eq!(channel.channel_time.n_challenges, 0);
assert_eq!(channel.channel_time.n_sent, 0);
channel.draw_random_bytes();
assert_eq!(channel.channel_time.n_challenges, 0);
assert_eq!(channel.channel_time.n_sent, 1);
channel.draw_felts(9);
assert_eq!(channel.channel_time.n_challenges, 0);
assert_eq!(channel.channel_time.n_sent, 6);
}
#[test]
fn test_draw_random_bytes() {
let mut channel = Poseidon252Channel::default();
let first_random_bytes = channel.draw_random_bytes();
// Assert that next random bytes are different.
assert_ne!(first_random_bytes, channel.draw_random_bytes());
}
#[test]
pub fn test_draw_felt() {
let mut channel = Poseidon252Channel::default();
let first_random_felt = channel.draw_felt();
// Assert that next random felt is different.
assert_ne!(first_random_felt, channel.draw_felt());
}
#[test]
pub fn test_draw_felts() {
let mut channel = Poseidon252Channel::default();
let mut random_felts = channel.draw_felts(5);
random_felts.extend(channel.draw_felts(4));
// Assert that all the random felts are unique.
assert_eq!(
random_felts.len(),
random_felts.iter().collect::<BTreeSet<_>>().len()
);
}
#[test]
pub fn test_mix_felts() {
let mut channel = Poseidon252Channel::default();
let initial_digest = channel.digest;
let felts: Vec<SecureField> = (0..2)
.map(|i| SecureField::from(m31!(i + 1923782)))
.collect();
channel.mix_felts(felts.as_slice());
assert_ne!(initial_digest, channel.digest);
}
}

View File

@ -0,0 +1,590 @@
use std::iter;
use ark_bls12_381::Fr as BlsFr;
use ark_ff::{BigInteger, Field, PrimeField};
use crypto_bigint::{Encoding, NonZero, U256};
use super::{Channel, ChannelTime};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
pub const BYTES_PER_FELT252: usize = 32;
pub const FELTS_PER_HASH: usize = 8;
//Optimize constant to be real constants (no conversion) and merge duplicated code in VCS poseidono
fn poseidon_comp_consts(idx: usize) -> BlsFr {
match idx {
0 => BlsFr::from_be_bytes_mod_order(&[
111, 0, 122, 85, 17, 86, 179, 164, 73, 228, 73, 54, 183, 192, 147, 100, 74, 14, 211,
63, 51, 234, 204, 198, 40, 233, 66, 232, 54, 193, 168, 117,
]),
1 => BlsFr::from_be_bytes_mod_order(&[
54, 13, 116, 112, 97, 30, 71, 61, 53, 63, 98, 143, 118, 209, 16, 243, 78, 113, 22, 47,
49, 0, 59, 112, 87, 83, 140, 37, 150, 66, 99, 3,
]),
2 => BlsFr::from_be_bytes_mod_order(&[
75, 95, 236, 58, 160, 115, 223, 68, 1, 144, 145, 240, 7, 164, 76, 169, 150, 72, 73,
101, 247, 3, 109, 206, 62, 157, 9, 119, 237, 205, 192, 246,
]),
3 => BlsFr::from_be_bytes_mod_order(&[
103, 207, 24, 104, 175, 99, 150, 192, 184, 76, 206, 113, 94, 83, 159, 132, 158, 6, 205,
28, 56, 58, 197, 176, 97, 0, 199, 107, 204, 151, 58, 17,
]),
4 => BlsFr::from_be_bytes_mod_order(&[
85, 93, 180, 209, 220, 237, 129, 159, 93, 61, 231, 15, 222, 131, 241, 199, 211, 232,
201, 137, 104, 229, 22, 162, 58, 119, 26, 92, 156, 130, 87, 170,
]),
5 => BlsFr::from_be_bytes_mod_order(&[
43, 171, 148, 215, 174, 34, 45, 19, 93, 195, 198, 197, 254, 191, 170, 49, 73, 8, 172,
47, 18, 235, 224, 111, 189, 183, 66, 19, 191, 99, 24, 139,
]),
6 => BlsFr::from_be_bytes_mod_order(&[
102, 244, 75, 229, 41, 102, 130, 196, 250, 120, 130, 121, 157, 109, 208, 73, 182, 215,
210, 201, 80, 204, 249, 140, 242, 229, 13, 109, 30, 187, 119, 194,
]),
7 => BlsFr::from_be_bytes_mod_order(&[
21, 12, 147, 254, 246, 82, 251, 28, 43, 240, 62, 26, 41, 170, 135, 31, 239, 119, 231,
215, 54, 118, 108, 93, 9, 57, 217, 39, 83, 204, 93, 200,
]),
8 => BlsFr::from_be_bytes_mod_order(&[
50, 112, 102, 30, 104, 146, 139, 58, 149, 93, 85, 219, 86, 220, 87, 193, 3, 204, 10,
96, 20, 30, 137, 78, 20, 37, 157, 206, 83, 119, 130, 178,
]),
9 => BlsFr::from_be_bytes_mod_order(&[
7, 63, 17, 111, 4, 18, 46, 37, 160, 183, 175, 228, 226, 5, 114, 153, 180, 7, 195, 112,
242, 181, 161, 204, 206, 159, 185, 255, 195, 69, 175, 179,
]),
10 => BlsFr::from_be_bytes_mod_order(&[
64, 159, 218, 34, 85, 140, 254, 77, 61, 216, 220, 226, 79, 105, 231, 111, 140, 42, 174,
177, 221, 15, 9, 214, 94, 101, 76, 113, 243, 42, 162, 63,
]),
11 => BlsFr::from_be_bytes_mod_order(&[
42, 50, 236, 92, 78, 229, 177, 131, 122, 255, 208, 156, 31, 83, 245, 253, 85, 201, 205,
32, 97, 174, 147, 202, 142, 186, 215, 111, 199, 21, 84, 216,
]),
12 => BlsFr::from_be_bytes_mod_order(&[
88, 72, 235, 235, 89, 35, 233, 37, 85, 183, 18, 79, 255, 186, 93, 107, 213, 113, 198,
249, 132, 25, 94, 185, 207, 211, 163, 232, 235, 85, 177, 212,
]),
13 => BlsFr::from_be_bytes_mod_order(&[
39, 3, 38, 238, 3, 157, 241, 158, 101, 30, 44, 252, 116, 6, 40, 202, 99, 77, 36, 252,
110, 37, 89, 242, 45, 140, 203, 226, 146, 239, 238, 173,
]),
14 => BlsFr::from_be_bytes_mod_order(&[
39, 198, 100, 42, 198, 51, 188, 102, 220, 16, 15, 231, 252, 250, 84, 145, 138, 248,
149, 188, 224, 18, 241, 130, 160, 104, 252, 55, 193, 130, 226, 116,
]),
15 => BlsFr::from_be_bytes_mod_order(&[
27, 223, 216, 176, 20, 1, 199, 10, 210, 127, 87, 57, 105, 137, 18, 157, 113, 14, 31,
182, 171, 151, 106, 69, 156, 161, 134, 130, 226, 109, 127, 249,
]),
16 => BlsFr::from_be_bytes_mod_order(&[
73, 27, 155, 166, 152, 59, 207, 159, 5, 254, 71, 148, 173, 180, 74, 48, 135, 155, 248,
40, 150, 98, 225, 245, 125, 144, 246, 114, 65, 78, 138, 74,
]),
17 => BlsFr::from_be_bytes_mod_order(&[
22, 42, 20, 198, 47, 154, 137, 184, 20, 185, 214, 169, 200, 77, 214, 120, 244, 246,
251, 63, 144, 84, 211, 115, 200, 50, 216, 36, 38, 26, 53, 234,
]),
18 => BlsFr::from_be_bytes_mod_order(&[
45, 25, 62, 15, 118, 222, 88, 107, 42, 246, 247, 158, 49, 39, 254, 234, 172, 10, 31,
199, 30, 44, 240, 192, 247, 152, 36, 102, 123, 91, 107, 236,
]),
19 => BlsFr::from_be_bytes_mod_order(&[
70, 239, 216, 169, 162, 98, 214, 216, 253, 201, 202, 92, 4, 176, 152, 47, 36, 221, 204,
110, 152, 99, 136, 90, 106, 115, 42, 57, 6, 160, 123, 149,
]),
20 => BlsFr::from_be_bytes_mod_order(&[
80, 151, 23, 224, 194, 0, 227, 201, 45, 141, 202, 41, 115, 179, 219, 69, 240, 120, 130,
148, 53, 26, 208, 122, 231, 92, 187, 120, 6, 147, 167, 152,
]),
21 => BlsFr::from_be_bytes_mod_order(&[
114, 153, 178, 132, 100, 168, 201, 79, 185, 212, 223, 97, 56, 15, 57, 192, 220, 169,
194, 192, 20, 17, 135, 137, 226, 39, 37, 40, 32, 240, 27, 252,
]),
22 => BlsFr::from_be_bytes_mod_order(&[
4, 76, 163, 204, 74, 133, 215, 59, 129, 105, 110, 241, 16, 78, 103, 79, 79, 239, 248,
41, 132, 153, 15, 248, 93, 11, 245, 141, 200, 164, 170, 148,
]),
23 => BlsFr::from_be_bytes_mod_order(&[
28, 186, 242, 179, 113, 218, 198, 168, 29, 4, 83, 65, 109, 62, 35, 92, 184, 217, 226,
212, 243, 20, 244, 111, 97, 152, 120, 95, 12, 214, 185, 175,
]),
24 => BlsFr::from_be_bytes_mod_order(&[
29, 91, 39, 119, 105, 44, 32, 91, 14, 108, 73, 208, 97, 182, 181, 244, 41, 60, 74, 176,
56, 253, 187, 220, 52, 62, 7, 97, 15, 63, 237, 229,
]),
25 => BlsFr::from_be_bytes_mod_order(&[
86, 174, 124, 122, 82, 147, 189, 194, 62, 133, 225, 105, 140, 129, 199, 127, 138, 216,
140, 75, 51, 165, 120, 4, 55, 173, 4, 124, 110, 219, 89, 186,
]),
26 => BlsFr::from_be_bytes_mod_order(&[
46, 155, 219, 186, 61, 211, 75, 255, 170, 48, 83, 91, 221, 116, 154, 126, 6, 169, 173,
176, 193, 230, 249, 98, 246, 14, 151, 27, 141, 115, 176, 79,
]),
27 => BlsFr::from_be_bytes_mod_order(&[
45, 225, 24, 134, 177, 128, 17, 202, 139, 213, 186, 227, 105, 105, 41, 159, 222, 64,
251, 226, 109, 4, 123, 5, 3, 90, 19, 102, 31, 34, 65, 139,
]),
28 => BlsFr::from_be_bytes_mod_order(&[
46, 7, 222, 23, 128, 184, 167, 13, 13, 91, 74, 63, 24, 65, 220, 216, 42, 185, 57, 92,
68, 155, 233, 71, 188, 153, 136, 132, 186, 150, 167, 33,
]),
29 => BlsFr::from_be_bytes_mod_order(&[
15, 105, 241, 133, 77, 32, 202, 12, 187, 219, 99, 219, 213, 45, 173, 22, 37, 4, 64,
169, 157, 107, 138, 243, 130, 94, 76, 43, 183, 73, 37, 202,
]),
30 => BlsFr::from_be_bytes_mod_order(&[
93, 201, 135, 49, 142, 110, 89, 193, 175, 184, 123, 101, 93, 213, 140, 193, 210, 46,
81, 58, 5, 131, 140, 212, 88, 93, 4, 177, 53, 185, 87, 202,
]),
31 => BlsFr::from_be_bytes_mod_order(&[
72, 183, 37, 117, 133, 113, 201, 223, 108, 1, 220, 99, 154, 133, 240, 114, 151, 105,
107, 27, 182, 120, 99, 58, 41, 220, 145, 222, 149, 239, 83, 246,
]),
32 => BlsFr::from_be_bytes_mod_order(&[
94, 86, 94, 8, 192, 130, 16, 153, 37, 107, 86, 73, 14, 174, 225, 213, 115, 175, 209,
11, 182, 209, 125, 19, 202, 78, 92, 97, 27, 42, 55, 24,
]),
33 => BlsFr::from_be_bytes_mod_order(&[
46, 177, 178, 84, 23, 254, 23, 103, 13, 19, 93, 198, 57, 251, 9, 164, 108, 229, 17, 53,
7, 249, 109, 233, 129, 108, 5, 148, 34, 220, 112, 94,
]),
34 => BlsFr::from_be_bytes_mod_order(&[
17, 92, 208, 160, 100, 60, 251, 152, 140, 36, 203, 68, 195, 250, 180, 138, 255, 54,
198, 97, 210, 108, 196, 45, 184, 177, 189, 244, 149, 59, 216, 44,
]),
35 => BlsFr::from_be_bytes_mod_order(&[
38, 202, 41, 63, 123, 44, 70, 45, 6, 109, 115, 120, 185, 153, 134, 139, 187, 87, 221,
241, 78, 15, 149, 138, 222, 128, 22, 18, 49, 29, 4, 205,
]),
36 => BlsFr::from_be_bytes_mod_order(&[
65, 71, 64, 13, 142, 26, 172, 207, 49, 26, 107, 91, 118, 32, 17, 171, 62, 69, 50, 110,
77, 75, 157, 226, 105, 146, 129, 107, 153, 197, 40, 172,
]),
37 => BlsFr::from_be_bytes_mod_order(&[
107, 13, 183, 220, 204, 75, 161, 178, 104, 246, 189, 204, 77, 55, 40, 72, 212, 167, 41,
118, 194, 104, 234, 48, 81, 154, 47, 115, 230, 219, 77, 85,
]),
38 => BlsFr::from_be_bytes_mod_order(&[
23, 191, 27, 147, 196, 199, 224, 26, 42, 131, 10, 161, 98, 65, 44, 217, 15, 22, 11,
249, 247, 30, 150, 127, 245, 32, 157, 20, 178, 72, 32, 202,
]),
39 => BlsFr::from_be_bytes_mod_order(&[
75, 67, 28, 217, 239, 237, 188, 148, 207, 30, 202, 111, 158, 156, 24, 57, 208, 230,
106, 139, 255, 168, 200, 70, 76, 172, 129, 163, 157, 60, 248, 241,
]),
40 => BlsFr::from_be_bytes_mod_order(&[
53, 180, 26, 122, 196, 243, 197, 113, 162, 79, 132, 86, 54, 156, 133, 223, 224, 60, 3,
84, 189, 140, 253, 56, 5, 200, 111, 46, 125, 194, 147, 197,
]),
41 => BlsFr::from_be_bytes_mod_order(&[
59, 20, 128, 8, 5, 35, 196, 57, 67, 89, 39, 153, 72, 73, 190, 169, 100, 225, 77, 59,
235, 45, 221, 222, 114, 172, 21, 106, 244, 53, 208, 158,
]),
42 => BlsFr::from_be_bytes_mod_order(&[
44, 198, 129, 0, 49, 220, 27, 13, 73, 80, 133, 109, 201, 7, 213, 117, 8, 226, 134, 68,
42, 45, 62, 178, 39, 22, 24, 216, 116, 177, 76, 109,
]),
43 => BlsFr::from_be_bytes_mod_order(&[
111, 65, 65, 200, 64, 28, 90, 57, 91, 166, 121, 14, 253, 113, 199, 12, 4, 175, 234, 6,
195, 201, 40, 38, 188, 171, 221, 92, 181, 71, 125, 81,
]),
44 => BlsFr::from_be_bytes_mod_order(&[
37, 189, 187, 237, 161, 189, 232, 193, 5, 150, 24, 226, 175, 210, 239, 153, 158, 81,
122, 169, 59, 120, 52, 29, 145, 243, 24, 192, 159, 12, 181, 102,
]),
45 => BlsFr::from_be_bytes_mod_order(&[
57, 42, 74, 135, 88, 224, 110, 232, 185, 95, 51, 194, 93, 222, 138, 192, 42, 94, 208,
162, 123, 97, 146, 108, 198, 49, 52, 135, 7, 63, 127, 123,
]),
46 => BlsFr::from_be_bytes_mod_order(&[
39, 42, 85, 135, 138, 8, 68, 43, 154, 166, 17, 31, 77, 224, 9, 72, 94, 106, 111, 209,
93, 184, 147, 101, 231, 187, 206, 240, 46, 181, 134, 108,
]),
47 => BlsFr::from_be_bytes_mod_order(&[
99, 30, 193, 214, 210, 141, 217, 232, 36, 238, 137, 163, 7, 48, 174, 247, 171, 70, 58,
207, 201, 209, 132, 179, 85, 170, 5, 253, 105, 56, 234, 181,
]),
48 => BlsFr::from_be_bytes_mod_order(&[
78, 182, 253, 161, 15, 208, 251, 222, 2, 199, 68, 155, 251, 221, 195, 91, 205, 130, 37,
231, 229, 195, 131, 58, 8, 24, 161, 0, 64, 157, 198, 242,
]),
49 => BlsFr::from_be_bytes_mod_order(&[
45, 91, 48, 139, 12, 240, 44, 223, 239, 161, 60, 78, 96, 226, 98, 57, 166, 235, 186, 1,
22, 148, 221, 18, 155, 146, 91, 60, 91, 33, 224, 226,
]),
50 => BlsFr::from_be_bytes_mod_order(&[
22, 84, 159, 198, 175, 47, 59, 114, 221, 93, 41, 61, 114, 226, 229, 242, 68, 223, 244,
47, 24, 180, 108, 86, 239, 56, 197, 124, 49, 22, 115, 172,
]),
51 => BlsFr::from_be_bytes_mod_order(&[
66, 51, 38, 119, 255, 53, 156, 94, 141, 184, 54, 217, 245, 251, 84, 130, 46, 57, 189,
94, 34, 52, 11, 185, 186, 151, 91, 161, 169, 43, 227, 130,
]),
52 => BlsFr::from_be_bytes_mod_order(&[
73, 215, 210, 192, 180, 73, 229, 23, 155, 197, 204, 195, 180, 76, 96, 117, 217, 132,
155, 86, 16, 70, 95, 9, 234, 114, 93, 220, 151, 114, 58, 148,
]),
53 => BlsFr::from_be_bytes_mod_order(&[
100, 194, 15, 185, 13, 122, 0, 56, 49, 117, 124, 196, 198, 34, 111, 110, 73, 133, 252,
158, 203, 65, 107, 159, 104, 76, 160, 53, 29, 150, 121, 4,
]),
54 => BlsFr::from_be_bytes_mod_order(&[
89, 207, 244, 13, 232, 59, 82, 180, 27, 196, 67, 215, 151, 149, 16, 215, 113, 201, 64,
185, 117, 140, 168, 32, 254, 115, 181, 200, 213, 88, 9, 52,
]),
55 => BlsFr::from_be_bytes_mod_order(&[
83, 219, 39, 49, 115, 12, 57, 176, 78, 221, 135, 95, 227, 183, 200, 130, 128, 130, 133,
205, 188, 98, 29, 122, 244, 248, 13, 213, 62, 187, 113, 176,
]),
56 => BlsFr::from_be_bytes_mod_order(&[
27, 16, 187, 122, 130, 175, 206, 57, 250, 105, 195, 162, 173, 82, 247, 109, 118, 57,
130, 101, 52, 66, 3, 17, 155, 113, 38, 217, 180, 104, 96, 223,
]),
57 => BlsFr::from_be_bytes_mod_order(&[
86, 27, 96, 18, 214, 102, 191, 225, 121, 196, 221, 127, 132, 205, 209, 83, 21, 150,
211, 170, 199, 197, 112, 12, 235, 49, 159, 145, 4, 106, 99, 201,
]),
58 => BlsFr::from_be_bytes_mod_order(&[
15, 30, 117, 5, 235, 217, 29, 47, 199, 156, 45, 247, 220, 152, 163, 190, 209, 179, 105,
104, 186, 4, 5, 192, 144, 210, 127, 106, 0, 183, 223, 200,
]),
59 => BlsFr::from_be_bytes_mod_order(&[
47, 49, 63, 175, 13, 63, 97, 135, 83, 122, 116, 151, 163, 180, 63, 70, 121, 127, 214,
227, 241, 142, 177, 202, 255, 69, 119, 86, 184, 25, 187, 32,
]),
60 => BlsFr::from_be_bytes_mod_order(&[
58, 92, 187, 109, 228, 80, 180, 129, 250, 60, 166, 28, 14, 209, 91, 197, 92, 173, 17,
235, 240, 247, 206, 184, 240, 188, 62, 115, 46, 203, 38, 246,
]),
61 => BlsFr::from_be_bytes_mod_order(&[
104, 29, 147, 65, 27, 248, 206, 99, 246, 113, 106, 239, 189, 14, 36, 80, 100, 84, 192,
52, 142, 227, 143, 171, 235, 38, 71, 2, 113, 76, 207, 148,
]),
62 => BlsFr::from_be_bytes_mod_order(&[
81, 120, 233, 64, 245, 0, 4, 49, 38, 70, 180, 54, 114, 127, 14, 128, 167, 184, 242,
233, 238, 31, 220, 103, 124, 72, 49, 167, 103, 39, 119, 251,
]),
63 => BlsFr::from_be_bytes_mod_order(&[
61, 171, 84, 188, 155, 239, 104, 141, 217, 32, 134, 226, 83, 180, 57, 214, 81, 186,
166, 226, 15, 137, 43, 98, 134, 85, 39, 203, 202, 145, 89, 130,
]),
64 => BlsFr::from_be_bytes_mod_order(&[
75, 60, 231, 83, 17, 33, 143, 154, 233, 5, 248, 78, 170, 91, 43, 56, 24, 68, 139, 191,
57, 114, 225, 170, 214, 157, 227, 33, 0, 144, 21, 208,
]),
65 => BlsFr::from_be_bytes_mod_order(&[
6, 219, 251, 66, 185, 121, 136, 77, 226, 128, 211, 22, 112, 18, 63, 116, 76, 36, 179,
59, 65, 15, 239, 212, 54, 128, 69, 172, 242, 183, 26, 227,
]),
66 => BlsFr::from_be_bytes_mod_order(&[
6, 141, 107, 70, 8, 170, 232, 16, 198, 240, 57, 234, 25, 115, 166, 62, 184, 210, 222,
114, 227, 210, 201, 236, 167, 252, 50, 210, 47, 24, 185, 211,
]),
67 => BlsFr::from_be_bytes_mod_order(&[
76, 92, 37, 69, 137, 169, 42, 54, 8, 74, 87, 211, 177, 217, 100, 39, 138, 204, 126, 79,
232, 246, 159, 41, 85, 149, 79, 39, 167, 156, 235, 239,
]),
68 => BlsFr::from_be_bytes_mod_order(&[
108, 186, 197, 225, 112, 9, 132, 235, 195, 45, 161, 91, 75, 185, 104, 63, 170, 186,
181, 95, 103, 204, 196, 247, 29, 149, 96, 179, 71, 90, 119, 235,
]),
69 => BlsFr::from_be_bytes_mod_order(&[
70, 3, 196, 3, 187, 250, 154, 23, 115, 138, 92, 98, 120, 234, 171, 28, 55, 236, 48,
176, 115, 122, 162, 64, 159, 196, 137, 128, 105, 235, 152, 60,
]),
70 => BlsFr::from_be_bytes_mod_order(&[
104, 148, 231, 226, 43, 44, 29, 92, 112, 167, 18, 166, 52, 90, 230, 177, 146, 169, 200,
51, 169, 35, 76, 49, 197, 106, 172, 209, 107, 194, 241, 0,
]),
71 => BlsFr::from_be_bytes_mod_order(&[
91, 226, 203, 188, 68, 5, 58, 208, 138, 250, 77, 30, 171, 199, 243, 210, 49, 238, 167,
153, 185, 63, 34, 110, 144, 91, 125, 77, 101, 197, 142, 187,
]),
72 => BlsFr::from_be_bytes_mod_order(&[
88, 229, 95, 40, 123, 69, 58, 152, 8, 98, 74, 140, 42, 53, 61, 82, 141, 160, 247, 231,
19, 165, 198, 208, 215, 113, 30, 71, 6, 63, 166, 17,
]),
73 => BlsFr::from_be_bytes_mod_order(&[
54, 110, 191, 175, 163, 173, 56, 28, 14, 226, 88, 201, 184, 253, 252, 205, 184, 104,
167, 215, 225, 241, 246, 154, 43, 93, 252, 197, 87, 37, 85, 223,
]),
74 => BlsFr::from_be_bytes_mod_order(&[
69, 118, 106, 183, 40, 150, 140, 100, 47, 144, 217, 124, 207, 85, 4, 221, 193, 5, 24,
168, 25, 235, 188, 196, 208, 156, 63, 93, 120, 77, 103, 206,
]),
75 => BlsFr::from_be_bytes_mod_order(&[
57, 103, 143, 101, 81, 47, 30, 228, 4, 219, 48, 36, 244, 29, 63, 86, 126, 246, 109,
137, 208, 68, 208, 34, 230, 188, 34, 158, 149, 188, 118, 177,
]),
76 => BlsFr::from_be_bytes_mod_order(&[
70, 58, 237, 29, 47, 31, 149, 94, 48, 120, 190, 91, 247, 191, 196, 111, 192, 235, 140,
81, 85, 25, 6, 168, 134, 143, 24, 255, 174, 48, 207, 79,
]),
77 => BlsFr::from_be_bytes_mod_order(&[
33, 102, 143, 1, 106, 128, 99, 192, 213, 139, 119, 80, 163, 188, 47, 225, 207, 130,
194, 95, 153, 220, 1, 164, 229, 52, 200, 143, 229, 61, 133, 254,
]),
78 => BlsFr::from_be_bytes_mod_order(&[
57, 208, 9, 148, 168, 165, 4, 106, 27, 199, 73, 54, 62, 152, 167, 104, 227, 77, 234,
86, 67, 159, 225, 149, 75, 239, 66, 155, 197, 51, 22, 8,
]),
79 => BlsFr::from_be_bytes_mod_order(&[
77, 127, 93, 205, 120, 236, 233, 169, 51, 152, 77, 227, 44, 11, 72, 250, 194, 187, 169,
31, 38, 25, 150, 184, 233, 209, 2, 23, 115, 189, 7, 204,
]),
_ => BlsFr::ZERO,
}
}
/// A channel that can be used to draw random elements from a PoseidonBLS hash.
#[derive(Clone, Default)]
pub struct PoseidonBLSChannel {
digest: BlsFr,
pub channel_time: ChannelTime,
}
pub fn poseidon_hash_bls(x: BlsFr, y: BlsFr) -> BlsFr {
let mut state = [x, y, BlsFr::ZERO];
poseidon_permute_comp_bls(&mut state);
state[0] + x
}
pub fn poseidon_permute_comp_bls(state: &mut [BlsFr; 3]) {
let mut idx = 0;
mix(state);
// Full rounds
for _ in 0..4 {
round_comp(state, idx, true);
idx += 3;
}
// Partial rounds
for _ in 0..56 {
round_comp(state, idx, false);
idx += 1;
}
// Full rounds
for _ in 0..4 {
round_comp(state, idx, true);
idx += 3;
}
}
#[inline]
fn round_comp(state: &mut [BlsFr; 3], idx: usize, full: bool) {
if full {
state[0] += poseidon_comp_consts(idx);
state[1] += poseidon_comp_consts(idx + 1);
state[2] += poseidon_comp_consts(idx + 2);
// Optimize multiplication
state[0] = state[0] * state[0] * state[0] * state[0] * state[0];
state[1] = state[1] * state[1] * state[1] * state[1] * state[1];
state[2] = state[2] * state[2] * state[2] * state[2] * state[2];
} else {
state[0] += poseidon_comp_consts(idx);
state[2] = state[2] * state[2] * state[2] * state[2] * state[2];
}
mix(state);
}
#[inline(always)]
fn mix(state: &mut [BlsFr; 3]) {
state[0] = state[0] + state[1] + state[2];
state[1] = state[0] + state[1];
state[2] = state[0] + state[2];
}
pub fn poseidon_hash_many_bls(msgs: &[BlsFr]) -> BlsFr {
let mut state = [BlsFr::ZERO, BlsFr::ZERO, BlsFr::ZERO];
let mut iter = msgs.chunks_exact(2);
for msg in iter.by_ref() {
state[0] += msg[0];
state[1] += msg[1];
poseidon_permute_comp_bls(&mut state);
}
let r = iter.remainder();
if r.len() == 1 {
state[0] += r[0];
}
state[r.len()] += BlsFr::ONE;
poseidon_permute_comp_bls(&mut state);
state[0]
}
impl PoseidonBLSChannel {
pub fn digest(&self) -> BlsFr {
self.digest
}
pub fn update_digest(&mut self, new_digest: BlsFr) {
self.digest = new_digest;
self.channel_time.inc_challenges();
}
fn draw_felt252(&mut self) -> BlsFr {
let res = poseidon_hash_bls(self.digest, BlsFr::from(self.channel_time.n_sent as u64));
self.channel_time.inc_sent();
res
}
// TODO(spapini): Understand if we really need uniformity here.
/// Generates a close-to uniform random vector of BaseField elements.
fn draw_base_felts(&mut self) -> [BaseField; 8] {
let shift = NonZero::new(U256::from_u64(1u64 << 31)).unwrap();
let mut cur = self.draw_felt252();
let u32s: [u32; 8usize] = std::array::from_fn(|_| {
let (quotient, reminder) =
U256::from_be_slice(&cur.into_bigint().to_bytes_be()).div_rem(&shift);
cur = BlsFr::from_be_bytes_mod_order(&quotient.to_be_bytes());
u32::from_str_radix(&reminder.to_string(),16).unwrap()
});
u32s.into_iter()
.map(|x| BaseField::reduce(x as u64))
.collect::<Vec<_>>()
.try_into()
.unwrap()
}
}
impl Channel for PoseidonBLSChannel {
const BYTES_PER_HASH: usize = BYTES_PER_FELT252;
fn trailing_zeros(&self) -> u32 {
let bytes = self.digest.into_bigint().to_bytes_be();
u128::from_le_bytes(std::array::from_fn(|i| bytes[i])).trailing_zeros()
}
// TODO(spapini): Optimize.
fn mix_felts(&mut self, felts: &[SecureField]) {
let shift = BlsFr::from(1u64 << 31);
let mut res = Vec::with_capacity(felts.len() / 2 + 2);
res.push(self.digest);
for chunk in felts.chunks(2) {
res.push(
chunk
.iter()
.flat_map(|x| x.to_m31_array())
.fold(BlsFr::default(), |cur, y| {
cur * shift + BlsFr::from_be_bytes_mod_order(&y.0.to_be_bytes())
}),
);
}
// TODO(spapini): do we need length padding?
self.update_digest(poseidon_hash_many_bls(&res));
}
fn mix_nonce(&mut self, nonce: u64) {
self.update_digest(poseidon_hash_bls(self.digest, nonce.into()));
}
fn draw_felt(&mut self) -> SecureField {
let felts: [BaseField; FELTS_PER_HASH] = self.draw_base_felts();
SecureField::from_m31_array(felts[..SECURE_EXTENSION_DEGREE].try_into().unwrap())
}
fn draw_felts(&mut self, n_felts: usize) -> Vec<SecureField> {
let mut felts = iter::from_fn(|| Some(self.draw_base_felts())).flatten();
let secure_felts = iter::from_fn(|| {
Some(SecureField::from_m31_array([
felts.next()?,
felts.next()?,
felts.next()?,
felts.next()?,
]))
});
secure_felts.take(n_felts).collect()
}
fn draw_random_bytes(&mut self) -> Vec<u8> {
let shift = NonZero::new(U256::from_u64(1u64 << 8)).unwrap();
let mut cur = self.draw_felt252();
let bytes: [u8; 31] = std::array::from_fn(|_| {
let (quotient, reminder) =
U256::from_be_slice(&cur.into_bigint().to_bytes_be()).div_rem(&shift);
cur = BlsFr::from_be_bytes_mod_order(&quotient.to_be_bytes());
u8::from_str_radix(&reminder.to_string(),16).unwrap()
});
bytes.to_vec()
}
}
#[cfg(test)]
mod tests {
use std::collections::BTreeSet;
use crate::core::channel::poseidon_bls::PoseidonBLSChannel;
use crate::core::channel::Channel;
use crate::core::fields::qm31::SecureField;
use crate::m31;
#[test]
fn test_channel_time() {
let mut channel = PoseidonBLSChannel::default();
assert_eq!(channel.channel_time.n_challenges, 0);
assert_eq!(channel.channel_time.n_sent, 0);
channel.draw_random_bytes();
assert_eq!(channel.channel_time.n_challenges, 0);
assert_eq!(channel.channel_time.n_sent, 1);
channel.draw_felts(9);
assert_eq!(channel.channel_time.n_challenges, 0);
assert_eq!(channel.channel_time.n_sent, 6);
}
#[test]
fn test_draw_random_bytes() {
let mut channel = PoseidonBLSChannel::default();
let first_random_bytes = channel.draw_random_bytes();
// Assert that next random bytes are different.
assert_ne!(first_random_bytes, channel.draw_random_bytes());
}
#[test]
pub fn test_draw_felt() {
let mut channel = PoseidonBLSChannel::default();
let first_random_felt = channel.draw_felt();
// Assert that next random felt is different.
assert_ne!(first_random_felt, channel.draw_felt());
}
#[test]
pub fn test_draw_felts() {
let mut channel = PoseidonBLSChannel::default();
let mut random_felts = channel.draw_felts(5);
random_felts.extend(channel.draw_felts(4));
// Assert that all the random felts are unique.
assert_eq!(
random_felts.len(),
random_felts.iter().collect::<BTreeSet<_>>().len()
);
}
#[test]
pub fn test_mix_felts() {
let mut channel = PoseidonBLSChannel::default();
let initial_digest = channel.digest;
let felts: Vec<SecureField> = (0..2)
.map(|i| SecureField::from(m31!(i + 1923782)))
.collect();
channel.mix_felts(felts.as_slice());
assert_ne!(initial_digest, channel.digest);
}
}

View File

@ -0,0 +1,561 @@
use std::ops::{Add, Div, Mul, Neg, Sub};
use num_traits::{One, Zero};
use super::fields::m31::{BaseField, M31};
use super::fields::qm31::SecureField;
use super::fields::{ComplexConjugate, Field, FieldExpOps};
use crate::core::channel::Channel;
use crate::core::fields::qm31::P4;
use crate::math::utils::egcd;
/// A point on the complex circle. Treated as an additive group.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CirclePoint<F> {
pub x: F,
pub y: F,
}
impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>> CirclePoint<F> {
pub fn zero() -> Self {
Self {
x: F::one(),
y: F::zero(),
}
}
pub fn double(&self) -> Self {
*self + *self
}
/// Applies the circle's x-coordinate doubling map.
///
/// # Examples
///
/// ```
/// use stwo_prover::core::circle::{CirclePoint, M31_CIRCLE_GEN};
/// use stwo_prover::core::fields::m31::M31;
/// let p = M31_CIRCLE_GEN.mul(17);
/// assert_eq!(CirclePoint::double_x(p.x), (p + p).x);
/// ```
pub fn double_x(x: F) -> F {
let sx = x.square();
sx + sx - F::one()
}
/// Returns the log order of a point.
///
/// All points have an order of the form `2^k`.
///
/// # Examples
///
/// ```
/// use stwo_prover::core::circle::{CirclePoint, M31_CIRCLE_GEN, M31_CIRCLE_LOG_ORDER};
/// use stwo_prover::core::fields::m31::M31;
/// assert_eq!(M31_CIRCLE_GEN.log_order(), M31_CIRCLE_LOG_ORDER);
/// ```
pub fn log_order(&self) -> u32
where
F: PartialEq + Eq,
{
// we only need the x-coordinate to check order since the only point
// with x=1 is the circle's identity
let mut res = 0;
let mut cur = self.x;
while cur != F::one() {
cur = Self::double_x(cur);
res += 1;
}
res
}
pub fn mul(&self, mut scalar: u128) -> CirclePoint<F> {
let mut res = Self::zero();
let mut cur = *self;
while scalar > 0 {
if scalar & 1 == 1 {
res = res + cur;
}
cur = cur.double();
scalar >>= 1;
}
res
}
pub fn repeated_double(&self, n: u32) -> Self {
let mut res = *self;
for _ in 0..n {
res = res.double();
}
res
}
pub fn conjugate(&self) -> CirclePoint<F> {
Self {
x: self.x,
y: -self.y,
}
}
pub fn antipode(&self) -> CirclePoint<F> {
Self {
x: -self.x,
y: -self.y,
}
}
pub fn into_ef<EF: From<F>>(&self) -> CirclePoint<EF> {
CirclePoint {
x: self.x.into(),
y: self.y.into(),
}
}
pub fn mul_signed(&self, off: isize) -> CirclePoint<F> {
if off > 0 {
self.mul(off as u128)
} else {
self.conjugate().mul(-off as u128)
}
}
}
impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>> Add
for CirclePoint<F>
{
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
let x = self.x * rhs.x - self.y * rhs.y;
let y = self.x * rhs.y + self.y * rhs.x;
Self { x, y }
}
}
impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>> Neg
for CirclePoint<F>
{
type Output = Self;
fn neg(self) -> Self::Output {
self.conjugate()
}
}
impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>> Sub
for CirclePoint<F>
{
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
self + (-rhs)
}
}
impl<F: Field> ComplexConjugate for CirclePoint<F> {
fn complex_conjugate(&self) -> Self {
Self {
x: self.x.complex_conjugate(),
y: self.y.complex_conjugate(),
}
}
}
impl CirclePoint<SecureField> {
pub fn get_point(index: u128) -> Self {
assert!(index < SECURE_FIELD_CIRCLE_ORDER);
SECURE_FIELD_CIRCLE_GEN.mul(index)
}
pub fn get_random_point<C: Channel>(channel: &mut C) -> Self {
let t = channel.draw_felt();
let t_square = t.square();
let one_plus_tsquared_inv = t_square.add(SecureField::one()).inverse();
let x = SecureField::one()
.add(t_square.neg())
.mul(one_plus_tsquared_inv);
let y = t.double().mul(one_plus_tsquared_inv);
Self { x, y }
}
}
/// A generator for the circle group over [M31].
///
/// # Examples
///
/// ```
/// use stwo_prover::core::circle::{CirclePoint, M31_CIRCLE_GEN};
/// use stwo_prover::core::fields::m31::M31;
///
/// // Adding a generator to itself (2^30) times should NOT yield the identity.
/// let circle_point = M31_CIRCLE_GEN.repeated_double(30);
/// assert_ne!(circle_point, CirclePoint::zero());
///
/// // Shown above ord(M31_CIRCLE_GEN) > 2^30 . Group order is 2^31.
/// // Ord(M31_CIRCLE_GEN) must be a divisor of it, Hence ord(M31_CIRCLE_GEN) = 2^31.
/// // Adding the generator to itself (2^31) times should yield the identity.
/// let circle_point = M31_CIRCLE_GEN.repeated_double(31);
/// assert_eq!(circle_point, CirclePoint::zero());
/// ```
pub const M31_CIRCLE_GEN: CirclePoint<M31> = CirclePoint {
x: M31::from_u32_unchecked(2),
y: M31::from_u32_unchecked(1268011823),
};
/// Order of [M31_CIRCLE_GEN].
pub const M31_CIRCLE_LOG_ORDER: u32 = 31;
/// A generator for the circle group over [SecureField].
pub const SECURE_FIELD_CIRCLE_GEN: CirclePoint<SecureField> = CirclePoint {
x: SecureField::from_u32_unchecked(1, 0, 478637715, 513582971),
y: SecureField::from_u32_unchecked(992285211, 649143431, 740191619, 1186584352),
};
/// Order of [SECURE_FIELD_CIRCLE_GEN].
pub const SECURE_FIELD_CIRCLE_ORDER: u128 = P4 - 1;
/// Integer i that represent the circle point i * CIRCLE_GEN. Treated as an
/// additive ring modulo `1 << M31_CIRCLE_LOG_ORDER`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Ord, PartialOrd)]
pub struct CirclePointIndex(pub usize);
impl CirclePointIndex {
pub fn zero() -> Self {
Self(0)
}
pub fn generator() -> Self {
Self(1)
}
pub fn reduce(self) -> Self {
Self(self.0 & ((1 << M31_CIRCLE_LOG_ORDER) - 1))
}
pub fn subgroup_gen(log_size: u32) -> Self {
assert!(log_size <= M31_CIRCLE_LOG_ORDER);
Self(1 << (M31_CIRCLE_LOG_ORDER - log_size))
}
pub fn to_point(self) -> CirclePoint<M31> {
M31_CIRCLE_GEN.mul(self.0 as u128)
}
pub fn half(self) -> Self {
assert!(self.0 & 1 == 0);
Self(self.0 >> 1)
}
pub fn try_div(&self, rhs: CirclePointIndex) -> Option<isize> {
// Find x s.t. x * rhs.0 = self.0 (mod CIRCLE_ORDER).
let (s, _t, g) = egcd(rhs.0 as isize, 1 << M31_CIRCLE_LOG_ORDER);
if self.0 as isize % g != 0 {
return None;
}
let res = s * self.0 as isize / g;
Some(res)
}
}
impl Add for CirclePointIndex {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
Self(self.0 + rhs.0).reduce()
}
}
impl Sub for CirclePointIndex {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
Self(self.0 + (1 << M31_CIRCLE_LOG_ORDER) - rhs.0).reduce()
}
}
impl Mul<usize> for CirclePointIndex {
type Output = Self;
fn mul(self, rhs: usize) -> Self::Output {
Self(self.0.wrapping_mul(rhs)).reduce()
}
}
impl Div for CirclePointIndex {
type Output = isize;
fn div(self, rhs: Self) -> Self::Output {
self.try_div(rhs).unwrap()
}
}
impl Neg for CirclePointIndex {
type Output = Self;
fn neg(self) -> Self::Output {
Self((1 << M31_CIRCLE_LOG_ORDER) - self.0).reduce()
}
}
/// Represents the coset initial + \<step\>.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Coset {
pub initial_index: CirclePointIndex,
pub initial: CirclePoint<M31>,
pub step_size: CirclePointIndex,
pub step: CirclePoint<M31>,
pub log_size: u32,
}
impl Coset {
pub fn new(initial_index: CirclePointIndex, log_size: u32) -> Self {
assert!(log_size <= M31_CIRCLE_LOG_ORDER);
let step_size = CirclePointIndex::subgroup_gen(log_size);
Self {
initial_index,
initial: initial_index.to_point(),
step: step_size.to_point(),
step_size,
log_size,
}
}
/// Creates a coset of the form <G_n>.
/// For example, for n=8, we get the point indices \[0,1,2,3,4,5,6,7\].
pub fn subgroup(log_size: u32) -> Self {
Self::new(CirclePointIndex::zero(), log_size)
}
/// Creates a coset of the form G_2n + \<G_n\>.
/// For example, for n=8, we get the point indices \[1,3,5,7,9,11,13,15\].
pub fn odds(log_size: u32) -> Self {
Self::new(CirclePointIndex::subgroup_gen(log_size + 1), log_size)
}
/// Creates a coset of the form G_4n + <G_n>.
/// For example, for n=8, we get the point indices \[1,5,9,13,17,21,25,29\].
/// Its conjugate will be \[3,7,11,15,19,23,27,31\].
pub fn half_odds(log_size: u32) -> Self {
Self::new(CirclePointIndex::subgroup_gen(log_size + 2), log_size)
}
/// Returns the size of the coset.
pub fn size(&self) -> usize {
1 << self.log_size()
}
/// Returns the log size of the coset.
pub fn log_size(&self) -> u32 {
self.log_size
}
pub fn iter(&self) -> CosetIterator<CirclePoint<M31>> {
CosetIterator {
cur: self.initial,
step: self.step,
remaining: self.size(),
}
}
pub fn iter_indices(&self) -> CosetIterator<CirclePointIndex> {
CosetIterator {
cur: self.initial_index,
step: self.step_size,
remaining: self.size(),
}
}
/// Returns a new coset comprising of all points in current coset doubled.
pub fn double(&self) -> Self {
assert!(self.log_size > 0);
Self {
initial_index: self.initial_index * 2,
initial: self.initial.double(),
step: self.step.double(),
step_size: self.step_size * 2,
log_size: self.log_size.saturating_sub(1),
}
}
pub fn repeated_double(&self, n_doubles: u32) -> Self {
(0..n_doubles).fold(*self, |coset, _| coset.double())
}
pub fn is_doubling_of(&self, other: Self) -> bool {
self.log_size <= other.log_size
&& *self == other.repeated_double(other.log_size - self.log_size)
}
pub fn initial(&self) -> CirclePoint<M31> {
self.initial
}
pub fn index_at(&self, index: usize) -> CirclePointIndex {
self.initial_index + self.step_size.mul(index)
}
pub fn at(&self, index: usize) -> CirclePoint<M31> {
self.index_at(index).to_point()
}
pub fn shift(&self, shift_size: CirclePointIndex) -> Self {
let initial_index = self.initial_index + shift_size;
Self {
initial_index,
initial: initial_index.to_point(),
..*self
}
}
/// Creates the conjugate coset: -initial -\<step\>.
pub fn conjugate(&self) -> Self {
let initial_index = -self.initial_index;
let step_size = -self.step_size;
Self {
initial_index,
initial: initial_index.to_point(),
step_size,
step: step_size.to_point(),
log_size: self.log_size,
}
}
pub fn find(&self, i: CirclePointIndex) -> Option<usize> {
let res = (i - self.initial_index).try_div(self.step_size)?;
Some(res.rem_euclid(self.size() as isize) as usize)
}
}
impl IntoIterator for Coset {
type Item = CirclePoint<BaseField>;
type IntoIter = CosetIterator<CirclePoint<BaseField>>;
/// Iterates over the points in the coset.
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
#[derive(Clone)]
pub struct CosetIterator<T: Add> {
pub cur: T,
pub step: T,
pub remaining: usize,
}
impl<T: Add<Output = T> + Copy> Iterator for CosetIterator<T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
if self.remaining == 0 {
return None;
}
self.remaining -= 1;
let res = self.cur;
self.cur = self.cur + self.step;
Some(res)
}
}
#[cfg(test)]
mod tests {
use std::collections::BTreeSet;
use num_traits::{One, Pow};
use super::{CirclePointIndex, Coset};
use crate::core::channel::Blake2sChannel;
use crate::core::circle::{CirclePoint, SECURE_FIELD_CIRCLE_GEN};
use crate::core::fields::qm31::{SecureField, P4};
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::CanonicCoset;
#[test]
fn test_iterator() {
let coset = Coset::new(CirclePointIndex(1), 3);
let actual_indices: Vec<_> = coset.iter_indices().collect();
let expected_indices = vec![
CirclePointIndex(1),
CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 1,
CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 2,
CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 3,
CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 4,
CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 5,
CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 6,
CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 7,
];
assert_eq!(actual_indices, expected_indices);
let actual_points = coset.iter().collect::<Vec<_>>();
let expected_points: Vec<_> = expected_indices.iter().map(|i| i.to_point()).collect();
assert_eq!(actual_points, expected_points);
}
#[test]
fn test_coset_is_half_coset_with_conjugate() {
let canonic_coset = CanonicCoset::new(8);
let coset_points = BTreeSet::from_iter(canonic_coset.coset().iter());
let half_coset_points = BTreeSet::from_iter(canonic_coset.half_coset().iter());
let half_coset_conjugate_points =
BTreeSet::from_iter(canonic_coset.half_coset().conjugate().iter());
assert!((&half_coset_points & &half_coset_conjugate_points).is_empty());
assert_eq!(
coset_points,
&half_coset_points | &half_coset_conjugate_points
)
}
#[test]
pub fn test_get_random_circle_point() {
let mut channel = Blake2sChannel::default();
let first_random_circle_point = CirclePoint::get_random_point(&mut channel);
// Assert that the next random circle point is different.
assert_ne!(
first_random_circle_point,
CirclePoint::get_random_point(&mut channel)
);
}
#[test]
pub fn test_secure_field_circle_gen() {
let prime_factors = [
(2, 33),
(3, 2),
(5, 1),
(7, 1),
(11, 1),
(31, 1),
(151, 1),
(331, 1),
(733, 1),
(1709, 1),
(368140581013, 1),
];
assert_eq!(
prime_factors
.iter()
.map(|(p, e)| p.pow(*e as u32))
.product::<u128>(),
P4 - 1
);
assert_eq!(
SECURE_FIELD_CIRCLE_GEN.x.square() + SECURE_FIELD_CIRCLE_GEN.y.square(),
SecureField::one()
);
assert_eq!(SECURE_FIELD_CIRCLE_GEN.mul(P4 - 1), CirclePoint::zero());
for (p, _) in prime_factors.iter() {
assert_ne!(
SECURE_FIELD_CIRCLE_GEN.mul((P4 - 1) / *p),
CirclePoint::zero()
);
}
}
}

View File

@ -0,0 +1,251 @@
use num_traits::One;
use super::circle::{CirclePoint, Coset};
use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::fields::ExtensionOf;
use super::pcs::quotients::PointSample;
use crate::core::fields::ComplexConjugate;
/// Evaluates a vanishing polynomial of the coset at a point.
pub fn coset_vanishing<F: ExtensionOf<BaseField>>(coset: Coset, mut p: CirclePoint<F>) -> F {
// Doubling a point `log_order - 1` times and taking the x coordinate is
// essentially evaluating a polynomial in x of degree `2^(log_order - 1)`. If
// the entire `2^log_order` points of the coset are roots (i.e. yield 0), then
// this is a vanishing polynomial of these points.
// Rotating the coset -coset.initial + step / 2 yields a canonic coset:
// `step/2 + <step>.`
// Doubling this coset log_order - 1 times yields the coset +-G_4.
// The polynomial x vanishes on these points.
// ```text
// X
// . .
// X
// ```
p = p - coset.initial.into_ef() + coset.step_size.half().to_point().into_ef();
//println!("p - coset.initial.into_ef() + coset.step_size.half().to_point().into_ef() = {:?} - {:?} + {:?}",p,coset.initial,coset.step_size.half().0);
let mut x = p.x;
// The formula for the x coordinate of the double of a point.
for _ in 1..coset.log_size {
x = CirclePoint::double_x(x);
}
x
}
/// Evaluates the polynomial that is used to exclude the excluded point at point
/// p. Note that this polynomial has a zero of multiplicity 2 at the excluded
/// point.
pub fn point_excluder<F: ExtensionOf<BaseField>>(
excluded: CirclePoint<BaseField>,
p: CirclePoint<F>,
) -> F {
(p - excluded.into_ef()).x - BaseField::one()
}
// A vanishing polynomial on 2 circle points.
pub fn pair_vanishing<F: ExtensionOf<BaseField>>(
excluded0: CirclePoint<F>,
excluded1: CirclePoint<F>,
p: CirclePoint<F>,
) -> F {
// The algorithm check computes the area of the triangle formed by the
// 3 points. This is done using the determinant of:
// | p.x p.y 1 |
// | e0.x e0.y 1 |
// | e1.x e1.y 1 |
// This is a polynomial of degree 1 in p.x and p.y, and thus it is a line.
// It vanishes at e0 and e1.
(excluded0.y - excluded1.y) * p.x
+ (excluded1.x - excluded0.x) * p.y
+ (excluded0.x * excluded1.y - excluded0.y * excluded1.x)
}
/// Evaluates a vanishing polynomial of the vanish_point at a point.
/// Note that this function has a pole on the antipode of the vanish_point.
pub fn point_vanishing<F: ExtensionOf<BaseField>, EF: ExtensionOf<F>>(
vanish_point: CirclePoint<F>,
p: CirclePoint<EF>,
) -> EF {
let h = p - vanish_point.into_ef();
h.y / (EF::one() + h.x)
}
/// Evaluates a point on a line between a point and its complex conjugate.
/// Relies on the fact that every polynomial F over the base field holds:
/// F(p*) == F(p)* (* being the complex conjugate).
pub fn complex_conjugate_line(
point: CirclePoint<SecureField>,
value: SecureField,
p: CirclePoint<BaseField>,
) -> SecureField {
// TODO(AlonH): This assertion will fail at a probability of 1 to 2^62. Use a better solution.
assert_ne!(
point.y,
point.y.complex_conjugate(),
"Cannot evaluate a line with a single point ({point:?})."
);
value
+ (value.complex_conjugate() - value) * (-point.y + p.y)
/ (point.complex_conjugate().y - point.y)
}
/// Evaluates the coefficients of a line between a point and its complex conjugate. Specifically,
/// `a, b, and c, s.t. a*x + b -c*y = 0` for (x,y) being (sample.y, sample.value) and
/// (conj(sample.y), conj(sample.value)).
/// Relies on the fact that every polynomial F over the base
/// field holds: F(p*) == F(p)* (* being the complex conjugate).
pub fn complex_conjugate_line_coeffs(
sample: &PointSample,
alpha: SecureField,
) -> (SecureField, SecureField, SecureField) {
// TODO(AlonH): This assertion will fail at a probability of 1 to 2^62. Use a better solution.
assert_ne!(
sample.point.y,
sample.point.y.complex_conjugate(),
"Cannot evaluate a line with a single point ({:?}).",
sample.point
);
let a = sample.value.complex_conjugate() - sample.value;
let c = sample.point.complex_conjugate().y - sample.point.y;
let b = sample.value * c - a * sample.point.y;
(alpha * a, alpha * b, alpha * c)
}
#[cfg(test)]
mod tests {
use num_traits::Zero;
use super::{coset_vanishing, point_excluder, point_vanishing};
use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly};
use crate::core::circle::{CirclePoint, CirclePointIndex, Coset};
use crate::core::constraints::{complex_conjugate_line, pair_vanishing};
use crate::core::fields::m31::{BaseField, M31};
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{ComplexConjugate, FieldExpOps};
use crate::core::poly::circle::CanonicCoset;
use crate::core::poly::NaturalOrder;
use crate::core::test_utils::secure_eval_to_base_eval;
use crate::m31;
#[test]
fn test_coset_vanishing() {
let cosets = [
Coset::half_odds(5),
Coset::odds(5),
Coset::new(CirclePointIndex::zero(), 5),
Coset::half_odds(5).conjugate(),
];
for c0 in cosets.iter() {
for el in c0.iter() {
assert_eq!(coset_vanishing(*c0, el), BaseField::zero());
for c1 in cosets.iter() {
if c0 == c1 {
continue;
}
assert_ne!(coset_vanishing(*c1, el), BaseField::zero());
}
}
}
}
#[test]
fn test_point_excluder() {
let excluded = Coset::half_odds(5).at(10);
let point = (CirclePointIndex::generator() * 4).to_point();
let num = point_excluder(excluded, point) * point_excluder(excluded.conjugate(), point);
let denom = (point.x - excluded.x).pow(2);
assert_eq!(num, denom);
}
#[test]
fn test_pair_excluder() {
let excluded0 = Coset::half_odds(5).at(10);
let excluded1 = Coset::half_odds(5).at(13);
let point = (CirclePointIndex::generator() * 4).to_point();
assert_ne!(pair_vanishing(excluded0, excluded1, point), M31::zero());
assert_eq!(pair_vanishing(excluded0, excluded1, excluded0), M31::zero());
assert_eq!(pair_vanishing(excluded0, excluded1, excluded1), M31::zero());
}
#[test]
fn test_point_vanishing_success() {
let coset = Coset::odds(5);
let vanish_point = coset.at(2);
for el in coset.iter() {
if el == vanish_point {
assert_eq!(point_vanishing(vanish_point, el), BaseField::zero());
continue;
}
if el == vanish_point.antipode() {
continue;
}
assert_ne!(point_vanishing(vanish_point, el), BaseField::zero());
}
}
#[test]
#[should_panic(expected = "0 has no inverse")]
fn test_point_vanishing_failure() {
let coset = Coset::half_odds(6);
let point = coset.at(4);
point_vanishing(point, point.antipode());
}
#[test]
fn test_complex_conjugate_symmetry() {
// Create a polynomial over a base circle domain.
let polynomial = CpuCirclePoly::new((0..1 << 7).map(|i| m31!(i)).collect());
let oods_point = CirclePoint::get_point(9834759221);
// Assert that the base field polynomial is complex conjugate symmetric.
assert_eq!(
polynomial.eval_at_point(oods_point.complex_conjugate()),
polynomial.eval_at_point(oods_point).complex_conjugate()
);
}
#[test]
fn test_point_vanishing_degree() {
// Create a polynomial over a circle domain.
let log_domain_size = 7;
let domain_size = 1 << log_domain_size;
let polynomial = CpuCirclePoly::new((0..domain_size).map(|i| m31!(i)).collect());
// Create a larger domain.
let log_large_domain_size = log_domain_size + 1;
let large_domain_size = 1 << log_large_domain_size;
let large_domain = CanonicCoset::new(log_large_domain_size).circle_domain();
// Create a vanish point that is not in the large domain.
let vanish_point = CirclePoint::get_point(97);
let vanish_point_value = polynomial.eval_at_point(vanish_point);
// Compute the quotient polynomial.
let mut quotient_polynomial_values = Vec::with_capacity(large_domain_size as usize);
for point in large_domain.iter() {
let line = complex_conjugate_line(vanish_point, vanish_point_value, point);
let mut value = polynomial.eval_at_point(point.into_ef()) - line;
value /= pair_vanishing(
vanish_point,
vanish_point.complex_conjugate(),
point.into_ef(),
);
quotient_polynomial_values.push(value);
}
let quotient_evaluation = CpuCircleEvaluation::<SecureField, NaturalOrder>::new(
large_domain,
quotient_polynomial_values,
);
let quotient_polynomial = secure_eval_to_base_eval(&quotient_evaluation)
.bit_reverse()
.interpolate();
// Check that the quotient polynomial is indeed in the wanted fft space.
assert!(quotient_polynomial.is_in_fft_space(log_domain_size));
}
}

View File

@ -0,0 +1,21 @@
use std::ops::{Add, AddAssign, Mul, Sub};
use super::fields::m31::BaseField;
pub fn butterfly<F>(v0: &mut F, v1: &mut F, twid: BaseField)
where
F: Copy + AddAssign<F> + Sub<F, Output = F> + Mul<BaseField, Output = F>,
{
let tmp = *v1 * twid;
*v1 = *v0 - tmp;
*v0 += tmp;
}
pub fn ibutterfly<F>(v0: &mut F, v1: &mut F, itwid: BaseField)
where
F: Copy + AddAssign<F> + Add<F, Output = F> + Sub<F, Output = F> + Mul<BaseField, Output = F>,
{
let tmp = *v0;
*v0 = tmp + *v1;
*v1 = (tmp - *v1) * itwid;
}

View File

@ -0,0 +1,137 @@
use std::fmt::{Debug, Display};
use std::ops::{
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
};
use serde::{Deserialize, Serialize};
use super::{ComplexConjugate, FieldExpOps};
use crate::core::fields::m31::M31;
use crate::{impl_extension_field, impl_field};
pub const P2: u64 = 4611686014132420609; // (2 ** 31 - 1) ** 2
/// Complex extension field of M31.
/// Equivalent to M31\[x\] over (x^2 + 1) as the irreducible polynomial.
/// Represented as (a, b) of a + bi.
#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
pub struct CM31(pub M31, pub M31);
impl_field!(CM31, P2);
impl_extension_field!(CM31, M31);
impl CM31 {
pub const fn from_u32_unchecked(a: u32, b: u32) -> CM31 {
Self(M31::from_u32_unchecked(a), M31::from_u32_unchecked(b))
}
pub fn from_m31(a: M31, b: M31) -> CM31 {
Self(a, b)
}
}
impl Display for CM31 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{} + {}i", self.0, self.1)
}
}
impl Debug for CM31 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{} + {}i", self.0, self.1)
}
}
impl Mul for CM31 {
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
// (a + bi) * (c + di) = (ac - bd) + (ad + bc)i.
Self(
self.0 * rhs.0 - self.1 * rhs.1,
self.0 * rhs.1 + self.1 * rhs.0,
)
}
}
impl TryInto<M31> for CM31 {
type Error = ();
fn try_into(self) -> Result<M31, Self::Error> {
if self.1 != M31::zero() {
return Err(());
}
Ok(self.0)
}
}
impl FieldExpOps for CM31 {
fn inverse(&self) -> Self {
assert!(!self.is_zero(), "0 has no inverse");
// 1 / (a + bi) = (a - bi) / (a^2 + b^2).
Self(self.0, -self.1) * (self.0.square() + self.1.square()).inverse()
}
}
#[cfg(test)]
#[macro_export]
macro_rules! cm31 {
($m0:expr, $m1:expr) => {
CM31::from_u32_unchecked($m0, $m1)
};
}
#[cfg(test)]
mod tests {
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use super::CM31;
use crate::core::fields::m31::P;
use crate::core::fields::{FieldExpOps, IntoSlice};
use crate::m31;
#[test]
fn test_inverse() {
let cm = cm31!(1, 2);
let cm_inv = cm.inverse();
assert_eq!(cm * cm_inv, cm31!(1, 0));
}
#[test]
fn test_ops() {
let cm0 = cm31!(1, 2);
let cm1 = cm31!(4, 5);
let m = m31!(8);
let cm = CM31::from(m);
let cm0_x_cm1 = cm31!(P - 6, 13);
assert_eq!(cm0 + cm1, cm31!(5, 7));
assert_eq!(cm1 + m, cm1 + cm);
assert_eq!(cm0 * cm1, cm0_x_cm1);
assert_eq!(cm1 * m, cm1 * cm);
assert_eq!(-cm0, cm31!(P - 1, P - 2));
assert_eq!(cm0 - cm1, cm31!(P - 3, P - 3));
assert_eq!(cm1 - m, cm1 - cm);
assert_eq!(cm0_x_cm1 / cm1, cm31!(1, 2));
assert_eq!(cm1 / m, cm1 / cm);
}
#[test]
fn test_into_slice() {
let mut rng = SmallRng::seed_from_u64(0);
let x = (0..100).map(|_| rng.gen()).collect::<Vec<CM31>>();
let slice = CM31::into_slice(&x);
for i in 0..100 {
let corresponding_sub_slice = &slice[i * 8..(i + 1) * 8];
assert_eq!(
x[i],
cm31!(
u32::from_le_bytes(corresponding_sub_slice[..4].try_into().unwrap()),
u32::from_le_bytes(corresponding_sub_slice[4..].try_into().unwrap())
)
)
}
}
}

View File

@ -0,0 +1,258 @@
use std::fmt::Display;
use std::ops::{
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
};
use bytemuck::{Pod, Zeroable};
use rand::distributions::{Distribution, Standard};
use serde::{Deserialize, Serialize};
use super::{ComplexConjugate, FieldExpOps};
use crate::impl_field;
pub const MODULUS_BITS: u32 = 31;
pub const N_BYTES_FELT: usize = 4;
pub const P: u32 = 2147483647; // 2 ** 31 - 1
#[repr(transparent)]
#[derive(
Copy,
Clone,
Debug,
Default,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
Pod,
Zeroable,
Serialize,
Deserialize,
)]
pub struct M31(pub u32);
pub type BaseField = M31;
impl_field!(M31, P);
impl M31 {
/// Returns `val % P` when `val` is in the range `[0, 2P)`.
///
/// ```
/// use stwo_prover::core::fields::m31::{M31, P};
///
/// let val = 2 * P - 19;
/// assert_eq!(M31::partial_reduce(val), M31::from(P - 19));
/// ```
pub fn partial_reduce(val: u32) -> Self {
Self(val.checked_sub(P).unwrap_or(val))
}
/// Returns `val % P` when `val` is in the range `[0, P^2)`.
///
/// ```
/// use stwo_prover::core::fields::m31::{M31, P};
///
/// let val = (P as u64).pow(2) - 19;
/// assert_eq!(M31::reduce(val), M31::from(P - 19));
/// ```
pub fn reduce(val: u64) -> Self {
Self((((((val >> MODULUS_BITS) + val + 1) >> MODULUS_BITS) + val) & (P as u64)) as u32)
}
pub const fn from_u32_unchecked(arg: u32) -> Self {
Self(arg)
}
}
impl Display for M31 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl Add for M31 {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
Self::partial_reduce(self.0 + rhs.0)
}
}
impl Neg for M31 {
type Output = Self;
fn neg(self) -> Self::Output {
Self::partial_reduce(P - self.0)
}
}
impl Sub for M31 {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
Self::partial_reduce(self.0 + P - rhs.0)
}
}
impl Mul for M31 {
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
Self::reduce((self.0 as u64) * (rhs.0 as u64))
}
}
impl FieldExpOps for M31 {
/// ```
/// use num_traits::One;
/// use stwo_prover::core::fields::m31::BaseField;
/// use stwo_prover::core::fields::FieldExpOps;
///
/// let v = BaseField::from(19);
/// assert_eq!(v.inverse() * v, BaseField::one());
/// ```
fn inverse(&self) -> Self {
assert!(!self.is_zero(), "0 has no inverse");
pow2147483645(*self)
}
}
impl ComplexConjugate for M31 {
fn complex_conjugate(&self) -> Self {
*self
}
}
impl One for M31 {
fn one() -> Self {
Self(1)
}
}
impl Zero for M31 {
fn zero() -> Self {
Self(0)
}
fn is_zero(&self) -> bool {
*self == Self::zero()
}
}
impl From<usize> for M31 {
fn from(value: usize) -> Self {
M31::reduce(value.try_into().unwrap())
}
}
impl From<u32> for M31 {
fn from(value: u32) -> Self {
M31::reduce(value.into())
}
}
impl From<i32> for M31 {
fn from(value: i32) -> Self {
M31::reduce(value.try_into().unwrap())
}
}
impl Distribution<M31> for Standard {
// Not intended for cryptographic use. Should only be used in tests and benchmarks.
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> M31 {
M31(rng.gen_range(0..P))
}
}
#[cfg(test)]
#[macro_export]
macro_rules! m31 {
($m:expr) => {
$crate::core::fields::m31::M31::from_u32_unchecked($m)
};
}
/// Computes `v^((2^31-1)-2)`.
///
/// Computes the multiplicative inverse of [`M31`] elements with 37 multiplications vs naive 60
/// multiplications. Made generic to support both vectorized and non-vectorized implementations.
/// Multiplication tree found with [addchain](https://github.com/mmcloughlin/addchain).
///
/// ```
/// use stwo_prover::core::fields::m31::{pow2147483645, BaseField};
/// use stwo_prover::core::fields::FieldExpOps;
///
/// let v = BaseField::from(19);
/// assert_eq!(pow2147483645(v), v.pow(2147483645));
/// ```
pub fn pow2147483645<T: FieldExpOps>(v: T) -> T {
let t0 = sqn::<2, T>(v) * v;
let t1 = sqn::<1, T>(t0) * t0;
let t2 = sqn::<3, T>(t1) * t0;
let t3 = sqn::<1, T>(t2) * t0;
let t4 = sqn::<8, T>(t3) * t3;
let t5 = sqn::<8, T>(t4) * t3;
sqn::<7, T>(t5) * t2
}
/// Computes `v^(2*n)`.
fn sqn<const N: usize, T: FieldExpOps>(mut v: T) -> T {
for _ in 0..N {
v = v.square();
}
v
}
#[cfg(test)]
mod tests {
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use super::{M31, P};
use crate::core::fields::IntoSlice;
fn mul_p(a: u32, b: u32) -> u32 {
((a as u64 * b as u64) % P as u64) as u32
}
fn add_p(a: u32, b: u32) -> u32 {
(a + b) % P
}
fn neg_p(a: u32) -> u32 {
if a == 0 {
0
} else {
P - a
}
}
#[test]
fn test_basic_ops() {
let mut rng = SmallRng::seed_from_u64(0);
for _ in 0..10000 {
let x: u32 = rng.gen::<u32>() % P;
let y: u32 = rng.gen::<u32>() % P;
assert_eq!(m31!(add_p(x, y)), m31!(x) + m31!(y));
assert_eq!(m31!(mul_p(x, y)), m31!(x) * m31!(y));
assert_eq!(m31!(neg_p(x)), -m31!(x));
}
}
#[test]
fn test_into_slice() {
let mut rng = SmallRng::seed_from_u64(0);
let x = (0..100).map(|_| rng.gen()).collect::<Vec<M31>>();
let slice = M31::into_slice(&x);
for i in 0..100 {
assert_eq!(
x[i],
m31!(u32::from_le_bytes(
slice[i * 4..(i + 1) * 4].try_into().unwrap()
))
);
}
}
}

View File

@ -0,0 +1,489 @@
use std::fmt::{Debug, Display};
use std::iter::{Product, Sum};
use std::ops::{Mul, MulAssign, Neg};
use num_traits::{NumAssign, NumAssignOps, NumOps, One};
use super::backend::ColumnOps;
pub mod cm31;
pub mod m31;
pub mod qm31;
pub mod secure_column;
pub trait FieldOps<F: Field>: ColumnOps<F> {
// TODO(Ohad): change to use a mutable slice.
fn batch_inverse(column: &Self::Column, dst: &mut Self::Column);
}
pub trait FieldExpOps: Mul<Output = Self> + MulAssign + Sized + One + Copy {
fn square(&self) -> Self {
(*self) * (*self)
}
fn pow(&self, exp: u128) -> Self {
let mut res = Self::one();
let mut base = *self;
let mut exp = exp;
while exp > 0 {
if exp & 1 == 1 {
res *= base;
}
base = base.square();
exp >>= 1;
}
res
}
fn inverse(&self) -> Self;
/// Inverts a batch of elements using Montgomery's trick.
fn batch_inverse(column: &[Self], dst: &mut [Self]) {
const WIDTH: usize = 4;
let n = column.len();
debug_assert!(dst.len() >= n);
if n <= WIDTH || n % WIDTH != 0 {
batch_inverse_classic(column, dst);
return;
}
// First pass. Compute 'WIDTH' cumulative products in an interleaving fashion, reducing
// instruction dependency and allowing better pipelining.
let mut cum_prod: [Self; WIDTH] = [Self::one(); WIDTH];
dst[..WIDTH].copy_from_slice(&cum_prod);
for i in 0..n {
cum_prod[i % WIDTH] *= column[i];
dst[i] = cum_prod[i % WIDTH];
}
// Inverse cumulative products.
// Use classic batch inversion.
let mut tail_inverses = [Self::one(); WIDTH];
batch_inverse_classic(&dst[n - WIDTH..], &mut tail_inverses);
// Second pass.
for i in (WIDTH..n).rev() {
dst[i] = dst[i - WIDTH] * tail_inverses[i % WIDTH];
tail_inverses[i % WIDTH] *= column[i];
}
dst[0..WIDTH].copy_from_slice(&tail_inverses);
}
}
/// Assumes dst is initialized and of the same length as column.
fn batch_inverse_classic<T: FieldExpOps>(column: &[T], dst: &mut [T]) {
let n = column.len();
debug_assert!(dst.len() >= n);
dst[0] = column[0];
// First pass.
for i in 1..n {
dst[i] = dst[i - 1] * column[i];
}
// Inverse cumulative product.
let mut curr_inverse = dst[n - 1].inverse();
// Second pass.
for i in (1..n).rev() {
dst[i] = dst[i - 1] * curr_inverse;
curr_inverse *= column[i];
}
dst[0] = curr_inverse;
}
pub trait Field:
NumAssign
+ Neg<Output = Self>
+ ComplexConjugate
+ Copy
+ Default
+ Debug
+ Display
+ PartialOrd
+ Ord
+ Send
+ Sync
+ Sized
+ FieldExpOps
+ Product
+ for<'a> Product<&'a Self>
+ Sum
+ for<'a> Sum<&'a Self>
{
fn double(&self) -> Self {
(*self) + (*self)
}
}
/// # Safety
///
/// Do not use unless you are aware of the endianess in the platform you are compiling for, and the
/// Field element's representation in memory.
// TODO(Ohad): Do not compile on non-le targets.
pub unsafe trait IntoSlice<T: Sized>: Sized {
fn into_slice(sl: &[Self]) -> &[T] {
unsafe {
std::slice::from_raw_parts(
sl.as_ptr() as *const T,
std::mem::size_of_val(sl) / std::mem::size_of::<T>(),
)
}
}
}
unsafe impl<F: Field> IntoSlice<u8> for F {}
pub trait ComplexConjugate {
/// # Example
///
/// ```
/// use stwo_prover::core::fields::m31::P;
/// use stwo_prover::core::fields::qm31::QM31;
/// use stwo_prover::core::fields::ComplexConjugate;
///
/// let x = QM31::from_u32_unchecked(1, 2, 3, 4);
/// assert_eq!(
/// x.complex_conjugate(),
/// QM31::from_u32_unchecked(1, 2, P - 3, P - 4)
/// );
/// ```
fn complex_conjugate(&self) -> Self;
}
pub trait ExtensionOf<F: Field>: Field + From<F> + NumOps<F> + NumAssignOps<F> {
const EXTENSION_DEGREE: usize;
}
impl<F: Field> ExtensionOf<F> for F {
const EXTENSION_DEGREE: usize = 1;
}
#[macro_export]
macro_rules! impl_field {
($field_name: ty, $field_size: ident) => {
use std::iter::{Product, Sum};
use num_traits::{Num, One, Zero};
use $crate::core::fields::Field;
impl Num for $field_name {
type FromStrRadixErr = Box<dyn std::error::Error>;
fn from_str_radix(_str: &str, _radix: u32) -> Result<Self, Self::FromStrRadixErr> {
unimplemented!(
"Num::from_str_radix is not implemented for {}",
stringify!($field_name)
);
}
}
impl Field for $field_name {}
impl AddAssign for $field_name {
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
impl SubAssign for $field_name {
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
}
}
impl MulAssign for $field_name {
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
impl Div for $field_name {
type Output = Self;
#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, rhs: Self) -> Self::Output {
self * rhs.inverse()
}
}
impl DivAssign for $field_name {
fn div_assign(&mut self, rhs: Self) {
*self = *self / rhs;
}
}
impl Rem for $field_name {
type Output = Self;
fn rem(self, _rhs: Self) -> Self::Output {
unimplemented!("Rem is not implemented for {}", stringify!($field_name));
}
}
impl RemAssign for $field_name {
fn rem_assign(&mut self, _rhs: Self) {
unimplemented!(
"RemAssign is not implemented for {}",
stringify!($field_name)
);
}
}
impl Product for $field_name {
fn product<I>(mut iter: I) -> Self
where
I: Iterator<Item = Self>,
{
let first = iter.next().unwrap_or_else(Self::one);
iter.fold(first, |a, b| a * b)
}
}
impl<'a> Product<&'a Self> for $field_name {
fn product<I>(iter: I) -> Self
where
I: Iterator<Item = &'a Self>,
{
iter.map(|&v| v).product()
}
}
impl Sum for $field_name {
fn sum<I>(mut iter: I) -> Self
where
I: Iterator<Item = Self>,
{
let first = iter.next().unwrap_or_else(Self::zero);
iter.fold(first, |a, b| a + b)
}
}
impl<'a> Sum<&'a Self> for $field_name {
fn sum<I>(iter: I) -> Self
where
I: Iterator<Item = &'a Self>,
{
iter.map(|&v| v).sum()
}
}
};
}
/// Used to extend a field (with characteristic M31) by 2.
#[macro_export]
macro_rules! impl_extension_field {
($field_name: ident, $extended_field_name: ty) => {
use rand::distributions::{Distribution, Standard};
use $crate::core::fields::ExtensionOf;
impl ExtensionOf<M31> for $field_name {
const EXTENSION_DEGREE: usize =
<$extended_field_name as ExtensionOf<M31>>::EXTENSION_DEGREE * 2;
}
impl Add for $field_name {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
Self(self.0 + rhs.0, self.1 + rhs.1)
}
}
impl Neg for $field_name {
type Output = Self;
fn neg(self) -> Self::Output {
Self(-self.0, -self.1)
}
}
impl Sub for $field_name {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
Self(self.0 - rhs.0, self.1 - rhs.1)
}
}
impl One for $field_name {
fn one() -> Self {
Self(
<$extended_field_name>::one(),
<$extended_field_name>::zero(),
)
}
}
impl Zero for $field_name {
fn zero() -> Self {
Self(
<$extended_field_name>::zero(),
<$extended_field_name>::zero(),
)
}
fn is_zero(&self) -> bool {
*self == Self::zero()
}
}
impl Add<M31> for $field_name {
type Output = Self;
fn add(self, rhs: M31) -> Self::Output {
Self(self.0 + rhs, self.1)
}
}
impl Add<$field_name> for M31 {
type Output = $field_name;
fn add(self, rhs: $field_name) -> Self::Output {
rhs + self
}
}
impl Sub<M31> for $field_name {
type Output = Self;
fn sub(self, rhs: M31) -> Self::Output {
Self(self.0 - rhs, self.1)
}
}
impl Sub<$field_name> for M31 {
type Output = $field_name;
fn sub(self, rhs: $field_name) -> Self::Output {
-rhs + self
}
}
impl Mul<M31> for $field_name {
type Output = Self;
fn mul(self, rhs: M31) -> Self::Output {
Self(self.0 * rhs, self.1 * rhs)
}
}
impl Mul<$field_name> for M31 {
type Output = $field_name;
fn mul(self, rhs: $field_name) -> Self::Output {
rhs * self
}
}
impl Div<M31> for $field_name {
type Output = Self;
fn div(self, rhs: M31) -> Self::Output {
Self(self.0 / rhs, self.1 / rhs)
}
}
impl Div<$field_name> for M31 {
type Output = $field_name;
#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, rhs: $field_name) -> Self::Output {
rhs.inverse() * self
}
}
impl ComplexConjugate for $field_name {
fn complex_conjugate(&self) -> Self {
Self(self.0, -self.1)
}
}
impl From<M31> for $field_name {
fn from(x: M31) -> Self {
Self(x.into(), <$extended_field_name>::zero())
}
}
impl AddAssign<M31> for $field_name {
fn add_assign(&mut self, rhs: M31) {
*self = *self + rhs;
}
}
impl SubAssign<M31> for $field_name {
fn sub_assign(&mut self, rhs: M31) {
*self = *self - rhs;
}
}
impl MulAssign<M31> for $field_name {
fn mul_assign(&mut self, rhs: M31) {
*self = *self * rhs;
}
}
impl DivAssign<M31> for $field_name {
fn div_assign(&mut self, rhs: M31) {
*self = *self / rhs;
}
}
impl Rem<M31> for $field_name {
type Output = Self;
fn rem(self, _rhs: M31) -> Self::Output {
unimplemented!("Rem is not implemented for {}", stringify!($field_name));
}
}
impl RemAssign<M31> for $field_name {
fn rem_assign(&mut self, _rhs: M31) {
unimplemented!(
"RemAssign is not implemented for {}",
stringify!($field_name)
);
}
}
impl Distribution<$field_name> for Standard {
// Not intended for cryptographic use. Should only be used in tests and benchmarks.
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> $field_name {
$field_name(rng.gen(), rng.gen())
}
}
};
}
#[cfg(test)]
mod tests {
use num_traits::Zero;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use crate::core::fields::m31::M31;
use crate::core::fields::FieldExpOps;
#[test]
fn test_slice_batch_inverse() {
let mut rng = SmallRng::seed_from_u64(0);
let elements: [M31; 16] = rng.gen();
let expected = elements.iter().map(|e| e.inverse()).collect::<Vec<_>>();
let mut dst = [M31::zero(); 16];
M31::batch_inverse(&elements, &mut dst);
assert_eq!(expected, dst);
}
#[test]
#[should_panic]
fn test_slice_batch_inverse_wrong_dst_size() {
let mut rng = SmallRng::seed_from_u64(0);
let elements: [M31; 16] = rng.gen();
let mut dst = [M31::zero(); 15];
M31::batch_inverse(&elements, &mut dst);
}
}

View File

@ -0,0 +1,195 @@
use std::fmt::{Debug, Display};
use std::ops::{
Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
};
use serde::{Deserialize, Serialize};
use super::secure_column::SECURE_EXTENSION_DEGREE;
use super::{ComplexConjugate, FieldExpOps};
use crate::core::fields::cm31::CM31;
use crate::core::fields::m31::M31;
use crate::{impl_extension_field, impl_field};
pub const P4: u128 = 21267647892944572736998860269687930881; // (2 ** 31 - 1) ** 4
pub const R: CM31 = CM31::from_u32_unchecked(2, 1);
/// Extension field of CM31.
/// Equivalent to CM31\[x\] over (x^2 - 2 - i) as the irreducible polynomial.
/// Represented as ((a, b), (c, d)) of (a + bi) + (c + di)u.
#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
pub struct QM31(pub CM31, pub CM31);
pub type SecureField = QM31;
impl_field!(QM31, P4);
impl_extension_field!(QM31, CM31);
impl QM31 {
pub const fn from_u32_unchecked(a: u32, b: u32, c: u32, d: u32) -> Self {
Self(
CM31::from_u32_unchecked(a, b),
CM31::from_u32_unchecked(c, d),
)
}
pub fn from_m31(a: M31, b: M31, c: M31, d: M31) -> Self {
Self(CM31::from_m31(a, b), CM31::from_m31(c, d))
}
pub fn from_m31_array(array: [M31; SECURE_EXTENSION_DEGREE]) -> Self {
Self::from_m31(array[0], array[1], array[2], array[3])
}
pub fn to_m31_array(self) -> [M31; SECURE_EXTENSION_DEGREE] {
[self.0 .0, self.0 .1, self.1 .0, self.1 .1]
}
/// Returns the combined value, given the values of its composing base field polynomials at that
/// point.
pub fn from_partial_evals(evals: [Self; SECURE_EXTENSION_DEGREE]) -> Self {
let mut res = evals[0];
res += evals[1] * Self::from_u32_unchecked(0, 1, 0, 0);
res += evals[2] * Self::from_u32_unchecked(0, 0, 1, 0);
res += evals[3] * Self::from_u32_unchecked(0, 0, 0, 1);
res
}
// Note: Adding this as a Mul impl drives rust insane, and it tries to infer Qm31*Qm31 as
// QM31*CM31.
pub fn mul_cm31(self, rhs: CM31) -> Self {
Self(self.0 * rhs, self.1 * rhs)
}
}
impl Display for QM31 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "({}) + ({})u", self.0, self.1)
}
}
impl Debug for QM31 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "({}) + ({})u", self.0, self.1)
}
}
impl Mul for QM31 {
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
// (a + bu) * (c + du) = (ac + rbd) + (ad + bc)u.
Self(
self.0 * rhs.0 + R * self.1 * rhs.1,
self.0 * rhs.1 + self.1 * rhs.0,
)
}
}
impl From<usize> for QM31 {
fn from(value: usize) -> Self {
M31::from(value).into()
}
}
impl From<u32> for QM31 {
fn from(value: u32) -> Self {
M31::from(value).into()
}
}
impl From<i32> for QM31 {
fn from(value: i32) -> Self {
M31::from(value).into()
}
}
impl TryInto<M31> for QM31 {
type Error = ();
fn try_into(self) -> Result<M31, Self::Error> {
if self.1 != CM31::zero() {
return Err(());
}
self.0.try_into()
}
}
impl FieldExpOps for QM31 {
fn inverse(&self) -> Self {
assert!(!self.is_zero(), "0 has no inverse");
// (a + bu)^-1 = (a - bu) / (a^2 - (2+i)b^2).
let b2 = self.1.square();
let ib2 = CM31(-b2.1, b2.0);
let denom = self.0.square() - (b2 + b2 + ib2);
let denom_inverse = denom.inverse();
Self(self.0 * denom_inverse, -self.1 * denom_inverse)
}
}
#[cfg(test)]
#[macro_export]
macro_rules! qm31 {
($m0:expr, $m1:expr, $m2:expr, $m3:expr) => {{
use $crate::core::fields::qm31::QM31;
QM31::from_u32_unchecked($m0, $m1, $m2, $m3)
}};
}
#[cfg(test)]
mod tests {
use num_traits::One;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use super::QM31;
use crate::core::fields::m31::P;
use crate::core::fields::{FieldExpOps, IntoSlice};
use crate::m31;
#[test]
fn test_inverse() {
let qm = qm31!(1, 2, 3, 4);
let qm_inv = qm.inverse();
assert_eq!(qm * qm_inv, QM31::one());
}
#[test]
fn test_ops() {
let qm0 = qm31!(1, 2, 3, 4);
let qm1 = qm31!(4, 5, 6, 7);
let m = m31!(8);
let qm = QM31::from(m);
let qm0_x_qm1 = qm31!(P - 71, 93, P - 16, 50);
assert_eq!(qm0 + qm1, qm31!(5, 7, 9, 11));
assert_eq!(qm1 + m, qm1 + qm);
assert_eq!(qm0 * qm1, qm0_x_qm1);
assert_eq!(qm1 * m, qm1 * qm);
assert_eq!(-qm0, qm31!(P - 1, P - 2, P - 3, P - 4));
assert_eq!(qm0 - qm1, qm31!(P - 3, P - 3, P - 3, P - 3));
assert_eq!(qm1 - m, qm1 - qm);
assert_eq!(qm0_x_qm1 / qm1, qm31!(1, 2, 3, 4));
assert_eq!(qm1 / m, qm1 / qm);
}
#[test]
fn test_into_slice() {
let mut rng = SmallRng::seed_from_u64(0);
let x = (0..100).map(|_| rng.gen()).collect::<Vec<QM31>>();
let slice = QM31::into_slice(&x);
for i in 0..100 {
let corresponding_sub_slice = &slice[i * 16..(i + 1) * 16];
assert_eq!(
x[i],
qm31!(
u32::from_le_bytes(corresponding_sub_slice[..4].try_into().unwrap()),
u32::from_le_bytes(corresponding_sub_slice[4..8].try_into().unwrap()),
u32::from_le_bytes(corresponding_sub_slice[8..12].try_into().unwrap()),
u32::from_le_bytes(corresponding_sub_slice[12..16].try_into().unwrap())
)
)
}
}
}

View File

@ -0,0 +1,111 @@
use std::array;
use std::iter::zip;
use super::m31::BaseField;
use super::qm31::SecureField;
use super::{ExtensionOf, FieldOps};
use crate::core::backend::{Col, Column, CpuBackend};
pub const SECURE_EXTENSION_DEGREE: usize =
<SecureField as ExtensionOf<BaseField>>::EXTENSION_DEGREE;
/// A column major array of `SECURE_EXTENSION_DEGREE` base field columns, that represents a column
/// of secure field element coordinates.
#[derive(Clone, Debug)]
pub struct SecureColumnByCoords<B: FieldOps<BaseField>> {
pub columns: [Col<B, BaseField>; SECURE_EXTENSION_DEGREE],
}
impl SecureColumnByCoords<CpuBackend> {
// TODO(spapini): Remove when we no longer use CircleEvaluation<SecureField>.
pub fn to_vec(&self) -> Vec<SecureField> {
(0..self.len()).map(|i| self.at(i)).collect()
}
}
impl<B: FieldOps<BaseField>> SecureColumnByCoords<B> {
pub fn at(&self, index: usize) -> SecureField {
SecureField::from_m31_array(std::array::from_fn(|i| self.columns[i].at(index)))
}
pub fn zeros(len: usize) -> Self {
Self {
columns: std::array::from_fn(|_| Col::<B, BaseField>::zeros(len)),
}
}
/// # Safety
pub unsafe fn uninitialized(len: usize) -> Self {
Self {
columns: std::array::from_fn(|_| Col::<B, BaseField>::uninitialized(len)),
}
}
pub fn len(&self) -> usize {
self.columns[0].len()
}
pub fn is_empty(&self) -> bool {
self.columns[0].is_empty()
}
pub fn to_cpu(&self) -> SecureColumnByCoords<CpuBackend> {
SecureColumnByCoords {
columns: self.columns.clone().map(|c| c.to_cpu()),
}
}
pub fn set(&mut self, index: usize, value: SecureField) {
let values = value.to_m31_array();
#[allow(clippy::needless_range_loop)]
for i in 0..SECURE_EXTENSION_DEGREE {
self.columns[i].set(index, values[i]);
}
}
}
pub struct SecureColumnByCoordsIter<'a> {
column: &'a SecureColumnByCoords<CpuBackend>,
index: usize,
}
impl Iterator for SecureColumnByCoordsIter<'_> {
type Item = SecureField;
fn next(&mut self) -> Option<Self::Item> {
if self.index < self.column.len() {
let value = self.column.at(self.index);
self.index += 1;
Some(value)
} else {
None
}
}
}
impl<'a> IntoIterator for &'a SecureColumnByCoords<CpuBackend> {
type Item = SecureField;
type IntoIter = SecureColumnByCoordsIter<'a>;
fn into_iter(self) -> Self::IntoIter {
SecureColumnByCoordsIter {
column: self,
index: 0,
}
}
}
impl FromIterator<SecureField> for SecureColumnByCoords<CpuBackend> {
fn from_iter<I: IntoIterator<Item = SecureField>>(iter: I) -> Self {
let values = iter.into_iter();
let (lower_bound, _) = values.size_hint();
let mut columns = array::from_fn(|_| Vec::with_capacity(lower_bound));
for value in values {
let coords = value.to_m31_array();
zip(&mut columns, coords).for_each(|(col, coord)| col.push(coord));
}
SecureColumnByCoords { columns }
}
}
impl From<SecureColumnByCoords<CpuBackend>> for Vec<SecureField> {
fn from(column: SecureColumnByCoords<CpuBackend>) -> Self {
column.into_iter().collect()
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,566 @@
//! GKR batch prover for Grand Product and LogUp lookup arguments.
use std::borrow::Cow;
use std::iter::{successors, zip};
use std::ops::Deref;
use educe::Educe;
use itertools::Itertools;
use num_traits::{One, Zero};
use thiserror::Error;
use super::gkr_verifier::{GkrArtifact, GkrBatchProof, GkrMask};
use super::mle::{Mle, MleOps};
use super::sumcheck::MultivariatePolyOracle;
use super::utils::{eq, random_linear_combination, UnivariatePoly};
use crate::core::backend::{Col, Column, ColumnOps, CpuBackend};
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{Field, FieldExpOps};
use crate::core::lookups::sumcheck;
pub trait GkrOps: MleOps<BaseField> + MleOps<SecureField> {
/// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`.
///
/// Note [`Mle`] stores values in bit-reversed order.
///
/// [`eq(x, y)`]: crate::core::lookups::utils::eq
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField>;
/// Generates the next GKR layer from the current one.
fn next_layer(layer: &Layer<Self>) -> Layer<Self>;
/// Returns univariate polynomial `f(t) = sum_x h(t, x)` for all `x` in the boolean hypercube.
///
/// `claim` equals `f(0) + f(1)`.
///
/// For more context see docs of [`MultivariatePolyOracle::sum_as_poly_in_first_variable()`].
fn sum_as_poly_in_first_variable(
h: &GkrMultivariatePolyOracle<'_, Self>,
claim: SecureField,
) -> UnivariatePoly<SecureField>;
}
/// Stores evaluations of [`eq(x, y)`] on all boolean hypercube points of the form
/// `x = (0, x_1, ..., x_{n-1})`.
///
/// Evaluations are stored in bit-reversed order i.e. `evals[0] = eq((0, ..., 0, 0), y)`,
/// `evals[1] = eq((0, ..., 0, 1), y)`, etc.
///
/// [`eq(x, y)`]: crate::core::lookups::utils::eq
#[derive(Educe)]
#[educe(Debug, Clone)]
pub struct EqEvals<B: ColumnOps<SecureField>> {
y: Vec<SecureField>,
evals: Mle<B, SecureField>,
}
impl<B: GkrOps> EqEvals<B> {
pub fn generate(y: &[SecureField]) -> Self {
let y = y.to_vec();
if y.is_empty() {
let evals = Mle::new([SecureField::one()].into_iter().collect());
return Self { evals, y };
}
let evals = B::gen_eq_evals(&y[1..], eq(&[SecureField::zero()], &[y[0]]));
assert_eq!(evals.len(), 1 << (y.len() - 1));
Self { evals, y }
}
/// Returns the fixed vector `y` used to generate the evaluations.
pub fn y(&self) -> &[SecureField] {
&self.y
}
}
impl<B: ColumnOps<SecureField>> Deref for EqEvals<B> {
type Target = Col<B, SecureField>;
fn deref(&self) -> &Col<B, SecureField> {
&self.evals
}
}
/// Represents a layer in a binary tree structured GKR circuit.
///
/// Layers can contain multiple columns, for example [LogUp] which has separate columns for
/// numerators and denominators.
///
/// [LogUp]: https://eprint.iacr.org/2023/1284.pdf
#[derive(Educe)]
#[educe(Debug, Clone)]
pub enum Layer<B: GkrOps> {
GrandProduct(Mle<B, SecureField>),
LogUpGeneric {
numerators: Mle<B, SecureField>,
denominators: Mle<B, SecureField>,
},
LogUpMultiplicities {
numerators: Mle<B, BaseField>,
denominators: Mle<B, SecureField>,
},
/// All numerators implicitly equal "1".
LogUpSingles {
denominators: Mle<B, SecureField>,
},
}
impl<B: GkrOps> Layer<B> {
/// Returns the number of variables used to interpolate the layer's gate values.
pub fn n_variables(&self) -> usize {
match self {
Self::GrandProduct(mle)
| Self::LogUpSingles { denominators: mle }
| Self::LogUpMultiplicities {
denominators: mle, ..
}
| Self::LogUpGeneric {
denominators: mle, ..
} => mle.n_variables(),
}
}
fn is_output_layer(&self) -> bool {
self.n_variables() == 0
}
/// Produces the next layer from the current layer.
///
/// The next layer is strictly half the size of the current layer.
/// Returns [`None`] if called on an output layer.
pub fn next_layer(&self) -> Option<Self> {
if self.is_output_layer() {
return None;
}
Some(B::next_layer(self))
}
/// Returns each column output if the layer is an output layer, otherwise returns an `Err`.
fn try_into_output_layer_values(self) -> Result<Vec<SecureField>, NotOutputLayerError> {
if !self.is_output_layer() {
return Err(NotOutputLayerError);
}
Ok(match self {
Layer::LogUpSingles { denominators } => {
let numerator = SecureField::one();
let denominator = denominators.at(0);
vec![numerator, denominator]
}
Layer::LogUpMultiplicities {
numerators,
denominators,
} => {
let numerator = numerators.at(0).into();
let denominator = denominators.at(0);
vec![numerator, denominator]
}
Layer::LogUpGeneric {
numerators,
denominators,
} => {
let numerator = numerators.at(0);
let denominator = denominators.at(0);
vec![numerator, denominator]
}
Layer::GrandProduct(col) => {
vec![col.at(0)]
}
})
}
/// Returns a transformed layer with the first variable of each column fixed to `assignment`.
fn fix_first_variable(self, x0: SecureField) -> Self {
if self.n_variables() == 0 {
return self;
}
match self {
Self::GrandProduct(mle) => Self::GrandProduct(mle.fix_first_variable(x0)),
Self::LogUpGeneric {
numerators,
denominators,
} => Self::LogUpGeneric {
numerators: numerators.fix_first_variable(x0),
denominators: denominators.fix_first_variable(x0),
},
Self::LogUpMultiplicities {
numerators,
denominators,
} => Self::LogUpGeneric {
numerators: numerators.fix_first_variable(x0),
denominators: denominators.fix_first_variable(x0),
},
Self::LogUpSingles { denominators } => Self::LogUpSingles {
denominators: denominators.fix_first_variable(x0),
},
}
}
/// Represents the next GKR layer evaluation as a multivariate polynomial which uses this GKR
/// layer as input.
///
/// Layers can contain multiple columns `c_0, ..., c_{n-1}` with multivariate polynomial `g_i`
/// representing[^note] `c_i` in the next layer. These polynomials must be combined with
/// `lambda` into a single polynomial `h = g_0 + lambda * g_1 + ... + lambda^(n-1) *
/// g_{n-1}`. The oracle for `h` should be returned.
///
/// # Optimization: precomputed [`eq(x, y)`] evals
///
/// Let `y` be a fixed vector of length `m` and let `z` be a subvector comprising of the
/// last `k` elements of `y`. `h(x)` **must** equal some multivariate polynomial of the form
/// `eq(x, z) * p(x)`. A common operation will be computing the univariate polynomial `f(t) =
/// sum_x h(t, x)` for `x` in the boolean hypercube `{0, 1}^(k-1)`.
///
/// `eq_evals` stores evaluations of `eq((0, x), y)` for `x` in a potentially extended boolean
/// hypercube `{0, 1}^{m-1}`. These evaluations, on the extended hypercube, can be used directly
/// in computing the sums of `h(x)`, however a correction factor must be applied to the final
/// sum which is handled by [`correct_sum_as_poly_in_first_variable()`] in `O(m)`.
///
/// Being able to compute sums of `h(x)` using `eq_evals` in this way leads to a more efficient
/// implementation because the prover only has to generate `eq_evals` once for an entire batch
/// of multiple GKR layer instances.
///
/// [`eq(x, y)`]: crate::core::lookups::utils::eq
/// [^note]: By "representing" we mean `g_i` agrees with the next layer's `c_i` on the boolean
/// hypercube that interpolates `c_i`.
fn into_multivariate_poly(
self,
lambda: SecureField,
eq_evals: &EqEvals<B>,
) -> GkrMultivariatePolyOracle<'_, B> {
GkrMultivariatePolyOracle {
eq_evals: Cow::Borrowed(eq_evals),
input_layer: self,
eq_fixed_var_correction: SecureField::one(),
lambda,
}
}
/// Returns a copy of this layer with the [`CpuBackend`].
///
/// This operation is expensive but can be useful for small traces that are difficult to handle
/// depending on the backend. For example, the SIMD backend offloads to the CPU backend when
/// trace length becomes smaller than the SIMD lane count.
pub fn to_cpu(&self) -> Layer<CpuBackend> {
match self {
Layer::GrandProduct(mle) => Layer::GrandProduct(Mle::new(mle.to_cpu())),
Layer::LogUpGeneric {
numerators,
denominators,
} => Layer::LogUpGeneric {
numerators: Mle::new(numerators.to_cpu()),
denominators: Mle::new(denominators.to_cpu()),
},
Layer::LogUpMultiplicities {
numerators,
denominators,
} => Layer::LogUpMultiplicities {
numerators: Mle::new(numerators.to_cpu()),
denominators: Mle::new(denominators.to_cpu()),
},
Layer::LogUpSingles { denominators } => Layer::LogUpSingles {
denominators: Mle::new(denominators.to_cpu()),
},
}
}
}
#[derive(Debug)]
struct NotOutputLayerError;
/// Multivariate polynomial `P` that expresses the relation between two consecutive GKR layers.
///
/// When the input layer is [`Layer::GrandProduct`] (represented by multilinear column `inp`)
/// the polynomial represents:
///
/// ```text
/// P(x) = eq(x, y) * inp(x, 0) * inp(x, 1)
/// ```
///
/// When the input layer is LogUp (represented by multilinear columns `inp_numer` and
/// `inp_denom`) the polynomial represents:
///
/// ```text
/// numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)
/// denom(x) = inp_denom(x, 0) * inp_denom(x, 1)
///
/// P(x) = eq(x, y) * (numer(x) + lambda * denom(x))
/// ```
pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> {
/// `eq_evals` passed by `Layer::into_multivariate_poly()`.
pub eq_evals: Cow<'a, EqEvals<B>>,
pub input_layer: Layer<B>,
pub eq_fixed_var_correction: SecureField,
/// Used by LogUp to perform a random linear combination of the numerators and denominators.
pub lambda: SecureField,
}
impl<'a, B: GkrOps> MultivariatePolyOracle for GkrMultivariatePolyOracle<'a, B> {
fn n_variables(&self) -> usize {
self.input_layer.n_variables() - 1
}
fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly<SecureField> {
B::sum_as_poly_in_first_variable(self, claim)
}
fn fix_first_variable(self, challenge: SecureField) -> Self {
if self.is_constant() {
return self;
}
let z0 = self.eq_evals.y()[self.eq_evals.y().len() - self.n_variables()];
let eq_fixed_var_correction = self.eq_fixed_var_correction * eq(&[challenge], &[z0]);
Self {
eq_evals: self.eq_evals,
eq_fixed_var_correction,
input_layer: self.input_layer.fix_first_variable(challenge),
lambda: self.lambda,
}
}
}
impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> {
fn is_constant(&self) -> bool {
self.n_variables() == 0
}
/// Returns all input layer columns restricted to a line.
///
/// Let `l` be the line satisfying `l(0) = b*` and `l(1) = c*`. Oracles that represent constants
/// are expressed by values `c_i(b*)` and `c_i(c*)` where `c_i` represents the input GKR layer's
/// `i`th column (for binary tree GKR `b* = (r, 0)`, `c* = (r, 1)`).
///
/// If this oracle represents a constant, then each `c_i` restricted to `l` is returned.
/// Otherwise, an [`Err`] is returned.
///
/// For more context see <https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf> page 64.
fn try_into_mask(self) -> Result<GkrMask, NotConstantPolyError> {
if !self.is_constant() {
return Err(NotConstantPolyError);
}
let columns = match self.input_layer {
Layer::GrandProduct(mle) => vec![mle.to_cpu().try_into().unwrap()],
Layer::LogUpGeneric {
numerators,
denominators,
} => {
let numerators = numerators.to_cpu().try_into().unwrap();
let denominators = denominators.to_cpu().try_into().unwrap();
vec![numerators, denominators]
}
// Should never get called.
Layer::LogUpMultiplicities { .. } => unimplemented!(),
Layer::LogUpSingles { denominators } => {
let numerators = [SecureField::one(); 2];
let denominators = denominators.to_cpu().try_into().unwrap();
vec![numerators, denominators]
}
};
Ok(GkrMask::new(columns))
}
/// Returns a copy of this oracle with the [`CpuBackend`].
///
/// This operation is expensive but can be useful for small oracles that are difficult to handle
/// depending on the backend. For example, the SIMD backend offloads to the CPU backend when
/// trace length becomes smaller than the SIMD lane count.
pub fn to_cpu(&self) -> GkrMultivariatePolyOracle<'a, CpuBackend> {
// TODO(andrew): This block is not ideal.
let n_eq_evals = 1 << (self.n_variables() - 1);
let eq_evals = Cow::Owned(EqEvals {
evals: Mle::new((0..n_eq_evals).map(|i| self.eq_evals.at(i)).collect()),
y: self.eq_evals.y.to_vec(),
});
GkrMultivariatePolyOracle {
eq_evals,
eq_fixed_var_correction: self.eq_fixed_var_correction,
input_layer: self.input_layer.to_cpu(),
lambda: self.lambda,
}
}
}
/// Error returned when a polynomial is expected to be constant but it is not.
#[derive(Debug, Error)]
#[error("polynomial is not constant")]
pub struct NotConstantPolyError;
/// Batch proves lookup circuits with GKR.
///
/// The input layers should be committed to the channel before calling this function.
// GKR algorithm: <https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf> (page 64)
pub fn prove_batch<B: GkrOps>(
channel: &mut impl Channel,
input_layer_by_instance: Vec<Layer<B>>,
) -> (GkrBatchProof, GkrArtifact) {
let n_instances = input_layer_by_instance.len();
let n_layers_by_instance = input_layer_by_instance
.iter()
.map(|l| l.n_variables())
.collect_vec();
let n_layers = *n_layers_by_instance.iter().max().unwrap();
// Evaluate all instance circuits and collect the layer values.
let mut layers_by_instance = input_layer_by_instance
.into_iter()
.map(|input_layer| gen_layers(input_layer).into_iter().rev())
.collect_vec();
let mut output_claims_by_instance = vec![None; n_instances];
let mut layer_masks_by_instance = (0..n_instances).map(|_| Vec::new()).collect_vec();
let mut sumcheck_proofs = Vec::new();
let mut ood_point = Vec::new();
let mut claims_to_verify_by_instance = vec![None; n_instances];
for layer in 0..n_layers {
let n_remaining_layers = n_layers - layer;
// Check all the instances for output layers.
for (instance, layers) in layers_by_instance.iter_mut().enumerate() {
if n_layers_by_instance[instance] == n_remaining_layers {
let output_layer = layers.next().unwrap();
let output_layer_values = output_layer.try_into_output_layer_values().unwrap();
claims_to_verify_by_instance[instance] = Some(output_layer_values.clone());
output_claims_by_instance[instance] = Some(output_layer_values);
}
}
// Seed the channel with layer claims.
for claims_to_verify in claims_to_verify_by_instance.iter().flatten() {
channel.mix_felts(claims_to_verify);
}
let eq_evals = EqEvals::generate(&ood_point);
let sumcheck_alpha = channel.draw_felt();
let instance_lambda = channel.draw_felt();
let mut sumcheck_oracles = Vec::new();
let mut sumcheck_claims = Vec::new();
let mut sumcheck_instances = Vec::new();
// Create the multivariate polynomial oracles used with sumcheck.
for (instance, claims_to_verify) in claims_to_verify_by_instance.iter().enumerate() {
if let Some(claims_to_verify) = claims_to_verify {
let layer = layers_by_instance[instance].next().unwrap();
sumcheck_oracles.push(layer.into_multivariate_poly(instance_lambda, &eq_evals));
sumcheck_claims.push(random_linear_combination(claims_to_verify, instance_lambda));
sumcheck_instances.push(instance);
}
}
let (sumcheck_proof, sumcheck_ood_point, constant_poly_oracles, _) =
sumcheck::prove_batch(sumcheck_claims, sumcheck_oracles, sumcheck_alpha, channel);
sumcheck_proofs.push(sumcheck_proof);
let masks = constant_poly_oracles
.into_iter()
.map(|oracle| oracle.try_into_mask().unwrap())
.collect_vec();
// Seed the channel with the layer masks.
for (&instance, mask) in zip(&sumcheck_instances, &masks) {
channel.mix_felts(mask.columns().flatten());
layer_masks_by_instance[instance].push(mask.clone());
}
let challenge = channel.draw_felt();
ood_point = sumcheck_ood_point;
ood_point.push(challenge);
// Set the claims to prove in the layer above.
for (instance, mask) in zip(sumcheck_instances, masks) {
claims_to_verify_by_instance[instance] = Some(mask.reduce_at_point(challenge));
}
}
let output_claims_by_instance = output_claims_by_instance
.into_iter()
.map(Option::unwrap)
.collect();
let claims_to_verify_by_instance = claims_to_verify_by_instance
.into_iter()
.map(Option::unwrap)
.collect();
let proof = GkrBatchProof {
sumcheck_proofs,
layer_masks_by_instance,
output_claims_by_instance,
};
let artifact = GkrArtifact {
ood_point,
claims_to_verify_by_instance,
n_variables_by_instance: n_layers_by_instance,
};
(proof, artifact)
}
/// Executes the GKR circuit on the input layer and returns all the circuit's layers.
fn gen_layers<B: GkrOps>(input_layer: Layer<B>) -> Vec<Layer<B>> {
let n_variables = input_layer.n_variables();
let layers = successors(Some(input_layer), |layer| layer.next_layer()).collect_vec();
assert_eq!(layers.len(), n_variables + 1);
layers
}
/// Computes `r(t) = sum_x eq((t, x), y[-k:]) * p(t, x)` from evaluations of
/// `f(t) = sum_x eq(({0}^(n - k), 0, x), y) * p(t, x)`.
///
/// Note `claim` must equal `r(0) + r(1)` and `r` must have degree <= 3.
///
/// For more context see `Layer::into_multivariate_poly()` docs.
/// See also <https://ia.cr/2024/108> (section 3.2).
pub fn correct_sum_as_poly_in_first_variable(
f_at_0: SecureField,
f_at_2: SecureField,
claim: SecureField,
y: &[SecureField],
k: usize,
) -> UnivariatePoly<SecureField> {
assert_ne!(k, 0);
let n = y.len();
assert!(k <= n);
// We evaluated `f(0)` and `f(2)` - the inputs.
// We want to compute `r(t) = f(t) * eq(t, y[n - k]) / eq(0, y[:n - k + 1])`.
let a_const = eq(&vec![SecureField::zero(); n - k + 1], &y[..n - k + 1]).inverse();
// Find the additional root of `r(t)`, by finding the root of `eq(t, y[n - k])`:
// 0 = eq(t, y[n - k])
// = t * y[n - k] + (1 - t)(1 - y[n - k])
// = 1 - y[n - k] - t(1 - 2 * y[n - k])
// => t = (1 - y[n - k]) / (1 - 2 * y[n - k])
// = b
let b_const = (SecureField::one() - y[n - k]) / (SecureField::one() - y[n - k].double());
// We get that `r(t) = f(t) * eq(t, y[n - k]) * a`.
let r_at_0 = f_at_0 * eq(&[SecureField::zero()], &[y[n - k]]) * a_const;
let r_at_1 = claim - r_at_0;
let r_at_2 = f_at_2 * eq(&[BaseField::from(2).into()], &[y[n - k]]) * a_const;
let r_at_b = SecureField::zero();
// Interpolate.
UnivariatePoly::interpolate_lagrange(
&[
SecureField::zero(),
SecureField::one(),
SecureField::from(BaseField::from(2)),
b_const,
],
&[r_at_0, r_at_1, r_at_2, r_at_b],
)
}

View File

@ -0,0 +1,357 @@
//! GKR batch verifier for Grand Product and LogUp lookup arguments.
use thiserror::Error;
use super::sumcheck::{SumcheckError, SumcheckProof};
use super::utils::{eq, fold_mle_evals, random_linear_combination};
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::sumcheck;
use crate::core::lookups::utils::Fraction;
/// Partially verifies a batch GKR proof.
///
/// On successful verification the function returns a [`GkrArtifact`] which stores the out-of-domain
/// point and claimed evaluations in the input layer columns for each instance at the OOD point.
/// These claimed evaluations are not checked in this function - hence partial verification.
pub fn partially_verify_batch(
gate_by_instance: Vec<Gate>,
proof: &GkrBatchProof,
channel: &mut impl Channel,
) -> Result<GkrArtifact, GkrError> {
let GkrBatchProof {
sumcheck_proofs,
layer_masks_by_instance,
output_claims_by_instance,
} = proof;
if layer_masks_by_instance.len() != output_claims_by_instance.len() {
return Err(GkrError::MalformedProof);
}
let n_instances = layer_masks_by_instance.len();
let instance_n_layers = |instance: usize| layer_masks_by_instance[instance].len();
let n_layers = (0..n_instances).map(instance_n_layers).max().unwrap();
if n_layers != sumcheck_proofs.len() {
return Err(GkrError::MalformedProof);
}
if gate_by_instance.len() != n_instances {
return Err(GkrError::NumInstancesMismatch {
given: gate_by_instance.len(),
proof: n_instances,
});
}
let mut ood_point = vec![];
let mut claims_to_verify_by_instance = vec![None; n_instances];
for (layer, sumcheck_proof) in sumcheck_proofs.iter().enumerate() {
let n_remaining_layers = n_layers - layer;
// Check for output layers.
for instance in 0..n_instances {
if instance_n_layers(instance) == n_remaining_layers {
let output_claims = output_claims_by_instance[instance].clone();
claims_to_verify_by_instance[instance] = Some(output_claims);
}
}
// Seed the channel with layer claims.
for claims_to_verify in claims_to_verify_by_instance.iter().flatten() {
channel.mix_felts(claims_to_verify);
}
let sumcheck_alpha = channel.draw_felt();
let instance_lambda = channel.draw_felt();
let mut sumcheck_claims = Vec::new();
let mut sumcheck_instances = Vec::new();
// Prepare the sumcheck claim.
for (instance, claims_to_verify) in claims_to_verify_by_instance.iter().enumerate() {
if let Some(claims_to_verify) = claims_to_verify {
let n_unused_variables = n_layers - instance_n_layers(instance);
let doubling_factor = BaseField::from(1 << n_unused_variables);
let claim =
random_linear_combination(claims_to_verify, instance_lambda) * doubling_factor;
sumcheck_claims.push(claim);
sumcheck_instances.push(instance);
}
}
let sumcheck_claim = random_linear_combination(&sumcheck_claims, sumcheck_alpha);
let (sumcheck_ood_point, sumcheck_eval) =
sumcheck::partially_verify(sumcheck_claim, sumcheck_proof, channel)
.map_err(|source| GkrError::InvalidSumcheck { layer, source })?;
let mut layer_evals = Vec::new();
// Evaluate the circuit locally at sumcheck OOD point.
for &instance in &sumcheck_instances {
let n_unused = n_layers - instance_n_layers(instance);
let mask = &layer_masks_by_instance[instance][layer - n_unused];
let gate = &gate_by_instance[instance];
let gate_output = gate.eval(mask).map_err(|InvalidNumMaskColumnsError| {
let instance_layer = instance_n_layers(layer) - n_remaining_layers;
GkrError::InvalidMask {
instance,
instance_layer,
}
})?;
// TODO: Consider simplifying the code by just using the same eq eval for all instances
// regardless of size.
let eq_eval = eq(&ood_point[n_unused..], &sumcheck_ood_point[n_unused..]);
layer_evals.push(eq_eval * random_linear_combination(&gate_output, instance_lambda));
}
let layer_eval = random_linear_combination(&layer_evals, sumcheck_alpha);
if sumcheck_eval != layer_eval {
return Err(GkrError::CircuitCheckFailure {
claim: sumcheck_eval,
output: layer_eval,
layer,
});
}
// Seed the channel with the layer masks.
for &instance in &sumcheck_instances {
let n_unused = n_layers - instance_n_layers(instance);
let mask = &layer_masks_by_instance[instance][layer - n_unused];
channel.mix_felts(mask.columns().flatten());
}
// Set the OOD evaluation point for layer above.
let challenge = channel.draw_felt();
ood_point = sumcheck_ood_point;
ood_point.push(challenge);
// Set the claims to verify in the layer above.
for instance in sumcheck_instances {
let n_unused = n_layers - instance_n_layers(instance);
let mask = &layer_masks_by_instance[instance][layer - n_unused];
claims_to_verify_by_instance[instance] = Some(mask.reduce_at_point(challenge));
}
}
let claims_to_verify_by_instance = claims_to_verify_by_instance
.into_iter()
.map(Option::unwrap)
.collect();
Ok(GkrArtifact {
ood_point,
claims_to_verify_by_instance,
n_variables_by_instance: (0..n_instances).map(instance_n_layers).collect(),
})
}
/// Batch GKR proof.
pub struct GkrBatchProof {
/// Sum-check proof for each layer.
pub sumcheck_proofs: Vec<SumcheckProof>,
/// Mask for each layer for each instance.
pub layer_masks_by_instance: Vec<Vec<GkrMask>>,
/// Column circuit outputs for each instance.
pub output_claims_by_instance: Vec<Vec<SecureField>>,
}
/// Values of interest obtained from the execution of the GKR protocol.
pub struct GkrArtifact {
/// Out-of-domain (OOD) point for evaluating columns in the input layer.
pub ood_point: Vec<SecureField>,
/// The claimed evaluation at `ood_point` for each column in the input layer of each instance.
pub claims_to_verify_by_instance: Vec<Vec<SecureField>>,
/// The number of variables that interpolate the input layer of each instance.
pub n_variables_by_instance: Vec<usize>,
}
/// Defines how a circuit operates locally on two input rows to produce a single output row.
/// This local 2-to-1 constraint is what gives the whole circuit its "binary tree" structure.
///
/// Binary tree structured circuits have a highly regular wiring pattern that fit the structure of
/// the circuits defined in [Thaler13] which allow for efficient linear time (linear in size of the
/// circuit) GKR prover implementations.
///
/// [Thaler13]: https://eprint.iacr.org/2013/351.pdf
#[derive(Debug, Clone, Copy)]
pub enum Gate {
LogUp,
GrandProduct,
}
impl Gate {
/// Returns the output after applying the gate to the mask.
fn eval(&self, mask: &GkrMask) -> Result<Vec<SecureField>, InvalidNumMaskColumnsError> {
Ok(match self {
Self::LogUp => {
if mask.columns().len() != 2 {
return Err(InvalidNumMaskColumnsError);
}
let [numerator_a, numerator_b] = mask.columns()[0];
let [denominator_a, denominator_b] = mask.columns()[1];
let a = Fraction::new(numerator_a, denominator_a);
let b = Fraction::new(numerator_b, denominator_b);
let res = a + b;
vec![res.numerator, res.denominator]
}
Self::GrandProduct => {
if mask.columns().len() != 1 {
return Err(InvalidNumMaskColumnsError);
}
let [a, b] = mask.columns()[0];
vec![a * b]
}
})
}
}
/// Mask has an invalid number of columns
#[derive(Debug)]
struct InvalidNumMaskColumnsError;
/// Stores two evaluations of each column in a GKR layer.
#[derive(Debug, Clone)]
pub struct GkrMask {
columns: Vec<[SecureField; 2]>,
}
impl GkrMask {
pub fn new(columns: Vec<[SecureField; 2]>) -> Self {
Self { columns }
}
pub fn to_rows(&self) -> [Vec<SecureField>; 2] {
self.columns.iter().map(|[a, b]| (a, b)).unzip().into()
}
pub fn columns(&self) -> &[[SecureField; 2]] {
&self.columns
}
/// Returns all `p_i(x)` where `p_i` interpolates column `i` of the mask on `{0, 1}`.
pub fn reduce_at_point(&self, x: SecureField) -> Vec<SecureField> {
self.columns
.iter()
.map(|&[v0, v1]| fold_mle_evals(x, v0, v1))
.collect()
}
}
/// Error encountered during GKR protocol verification.
#[derive(Error, Debug)]
pub enum GkrError {
/// The proof is malformed.
#[error("proof data is invalid")]
MalformedProof,
/// Mask has an invalid number of columns.
#[error("mask in layer {instance_layer} of instance {instance} is invalid")]
InvalidMask {
instance: usize,
/// Layer of the instance (but not necessarily the batch).
instance_layer: LayerIndex,
},
/// There is a mismatch between the number of instances in the proof and the number of
/// instances passed for verification.
#[error("provided an invalid number of instances (given {given}, proof expects {proof})")]
NumInstancesMismatch { given: usize, proof: usize },
/// There was an error with one of the sumcheck proofs.
#[error("sum-check invalid in layer {layer}: {source}")]
InvalidSumcheck {
layer: LayerIndex,
source: SumcheckError,
},
/// The circuit polynomial the verifier evaluated doesn't match claim from sumcheck.
#[error("circuit check failed in layer {layer} (calculated {output}, claim {claim})")]
CircuitCheckFailure {
claim: SecureField,
output: SecureField,
layer: LayerIndex,
},
}
/// GKR layer index where 0 corresponds to the output layer.
pub type LayerIndex = usize;
#[cfg(test)]
mod tests {
use super::{partially_verify_batch, Gate, GkrArtifact, GkrError};
use crate::core::backend::CpuBackend;
use crate::core::channel::Channel;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::{prove_batch, Layer};
use crate::core::lookups::mle::Mle;
use crate::core::test_utils::test_channel;
#[test]
fn prove_batch_works() -> Result<(), GkrError> {
const LOG_N: usize = 5;
let mut channel = test_channel();
let col0 = Mle::<CpuBackend, SecureField>::new(channel.draw_felts(1 << LOG_N));
let col1 = Mle::<CpuBackend, SecureField>::new(channel.draw_felts(1 << LOG_N));
let product0 = col0.iter().product::<SecureField>();
let product1 = col1.iter().product::<SecureField>();
let input_layers = vec![
Layer::GrandProduct(col0.clone()),
Layer::GrandProduct(col1.clone()),
];
let (proof, _) = prove_batch(&mut test_channel(), input_layers);
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
n_variables_by_instance,
} = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?;
assert_eq!(n_variables_by_instance, [LOG_N, LOG_N]);
assert_eq!(proof.output_claims_by_instance.len(), 2);
assert_eq!(claims_to_verify_by_instance.len(), 2);
assert_eq!(proof.output_claims_by_instance[0], &[product0]);
assert_eq!(proof.output_claims_by_instance[1], &[product1]);
let claim0 = &claims_to_verify_by_instance[0];
let claim1 = &claims_to_verify_by_instance[1];
assert_eq!(claim0, &[col0.eval_at_point(&ood_point)]);
assert_eq!(claim1, &[col1.eval_at_point(&ood_point)]);
Ok(())
}
#[test]
fn prove_batch_with_different_sizes_works() -> Result<(), GkrError> {
const LOG_N0: usize = 5;
const LOG_N1: usize = 7;
let mut channel = test_channel();
let col0 = Mle::<CpuBackend, SecureField>::new(channel.draw_felts(1 << LOG_N0));
let col1 = Mle::<CpuBackend, SecureField>::new(channel.draw_felts(1 << LOG_N1));
let product0 = col0.iter().product::<SecureField>();
let product1 = col1.iter().product::<SecureField>();
let input_layers = vec![
Layer::GrandProduct(col0.clone()),
Layer::GrandProduct(col1.clone()),
];
let (proof, _) = prove_batch(&mut test_channel(), input_layers);
let GkrArtifact {
ood_point,
claims_to_verify_by_instance,
n_variables_by_instance,
} = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?;
assert_eq!(n_variables_by_instance, [LOG_N0, LOG_N1]);
assert_eq!(proof.output_claims_by_instance.len(), 2);
assert_eq!(claims_to_verify_by_instance.len(), 2);
assert_eq!(proof.output_claims_by_instance[0], &[product0]);
assert_eq!(proof.output_claims_by_instance[1], &[product1]);
let claim0 = &claims_to_verify_by_instance[0];
let claim1 = &claims_to_verify_by_instance[1];
let n_vars = ood_point.len();
assert_eq!(claim0, &[col0.eval_at_point(&ood_point[n_vars - LOG_N0..])]);
assert_eq!(claim1, &[col1.eval_at_point(&ood_point[n_vars - LOG_N1..])]);
Ok(())
}
}

View File

@ -0,0 +1,106 @@
use std::ops::{Deref, DerefMut};
use educe::Educe;
use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::fields::qm31::SecureField;
use crate::core::fields::Field;
pub trait MleOps<F: Field>: ColumnOps<F> + Sized {
/// Returns a transformed [`Mle`] where the first variable is fixed to `assignment`.
fn fix_first_variable(mle: Mle<Self, F>, assignment: SecureField) -> Mle<Self, SecureField>
where
Self: MleOps<SecureField>;
}
/// Multilinear Extension stored as evaluations of a multilinear polynomial over the boolean
/// hypercube in bit-reversed order.
#[derive(Educe)]
#[educe(Debug, Clone)]
pub struct Mle<B: ColumnOps<F>, F: Field> {
evals: Col<B, F>,
}
impl<B: MleOps<F>, F: Field> Mle<B, F> {
/// Creates a [`Mle`] from evaluations of a multilinear polynomial on the boolean hypercube.
///
/// # Panics
///
/// Panics if the number of evaluations is not a power of two.
pub fn new(evals: Col<B, F>) -> Self {
assert!(evals.len().is_power_of_two());
Self { evals }
}
pub fn into_evals(self) -> Col<B, F> {
self.evals
}
/// Returns a transformed polynomial where the first variable is fixed to `assignment`.
pub fn fix_first_variable(self, assignment: SecureField) -> Mle<B, SecureField>
where
B: MleOps<SecureField>,
{
B::fix_first_variable(self, assignment)
}
/// Returns the number of variables in the polynomial.
pub fn n_variables(&self) -> usize {
self.evals.len().ilog2() as usize
}
}
impl<B: ColumnOps<F>, F: Field> Deref for Mle<B, F> {
type Target = Col<B, F>;
fn deref(&self) -> &Col<B, F> {
&self.evals
}
}
impl<B: ColumnOps<F>, F: Field> DerefMut for Mle<B, F> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.evals
}
}
#[cfg(test)]
mod test {
use super::{Mle, MleOps};
use crate::core::backend::Column;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{ExtensionOf, Field};
impl<B, F> Mle<B, F>
where
F: Field,
SecureField: ExtensionOf<F>,
B: MleOps<F>,
{
/// Evaluates the multilinear polynomial at `point`.
pub(crate) fn eval_at_point(&self, point: &[SecureField]) -> SecureField {
pub fn eval(mle_evals: &[SecureField], p: &[SecureField]) -> SecureField {
match p {
[] => mle_evals[0],
&[p_i, ref p @ ..] => {
let (lhs, rhs) = mle_evals.split_at(mle_evals.len() / 2);
let lhs_eval = eval(lhs, p);
let rhs_eval = eval(rhs, p);
// Equivalent to `eq(0, p_i) * lhs_eval + eq(1, p_i) * rhs_eval`.
p_i * (rhs_eval - lhs_eval) + lhs_eval
}
}
}
let mle_evals = self
.clone()
.into_evals()
.to_cpu()
.into_iter()
.map(|v| v.into())
.collect::<Vec<_>>();
eval(&mle_evals, point)
}
}
}

View File

@ -0,0 +1,5 @@
pub mod gkr_prover;
pub mod gkr_verifier;
pub mod mle;
pub mod sumcheck;
pub mod utils;

View File

@ -0,0 +1,292 @@
//! Sum-check protocol that proves and verifies claims about `sum_x g(x)` for all x in `{0, 1}^n`.
//!
//! [`MultivariatePolyOracle`] provides methods for evaluating sums and making transformations on
//! `g` in the context of the protocol. It is intended to be used in conjunction with
//! [`prove_batch()`] to generate proofs.
use std::iter::zip;
use itertools::Itertools;
use num_traits::{One, Zero};
use thiserror::Error;
use super::utils::UnivariatePoly;
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
/// Something that can be seen as a multivariate polynomial `g(x_0, ..., x_{n-1})`.
pub trait MultivariatePolyOracle: Sized {
/// Returns the number of variables in `g`.
fn n_variables(&self) -> usize;
/// Computes the sum of `g(x_0, x_1, ..., x_{n-1})` over all `(x_1, ..., x_{n-1})` in
/// `{0, 1}^(n-1)`, effectively reducing the sum over `g` to a univariate polynomial in `x_0`.
///
/// `claim` equals the claimed sum of `g(x_0, x_2, ..., x_{n-1})` over all `(x_0, ..., x_{n-1})`
/// in `{0, 1}^n`. Knowing the claim can help optimize the implementation: Let `f` denote the
/// univariate polynomial we want to return. Note that `claim = f(0) + f(1)` so knowing `claim`
/// and either `f(0)` or `f(1)` allows determining the other.
fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly<SecureField>;
/// Returns a transformed oracle where the first variable of `g` is fixed to `challenge`.
///
/// The returned oracle represents the multivariate polynomial `g'`, defined as
/// `g'(x_1, ..., x_{n-1}) = g(challenge, x_1, ..., x_{n-1})`.
fn fix_first_variable(self, challenge: SecureField) -> Self;
}
/// Performs sum-check on a random linear combinations of multiple multivariate polynomials.
///
/// Let the multivariate polynomials be `g_0, ..., g_{n-1}`. A single sum-check is performed on
/// multivariate polynomial `h = g_0 + lambda * g_1 + ... + lambda^(n-1) * g_{n-1}`. The `g_i`s do
/// not need to have the same number of variables. `g_i`s with less variables are folded in the
/// latest possible round of the protocol. For instance with `g_0(x, y, z)` and `g_1(x, y)`
/// sum-check is performed on `h(x, y, z) = g_0(x, y, z) + lambda * g_1(y, z)`. Claim `c_i` should
/// equal the claimed sum of `g_i(x_0, ..., x_{j-1})` over all `(x_0, ..., x_{j-1})` in `{0, 1}^j`.
///
/// The degree of each `g_i` should not exceed [`MAX_DEGREE`] in any variable. The sum-check proof
/// of `h`, list of challenges (variable assignment) and the constant oracles (i.e. the `g_i` with
/// all variables fixed to the their corresponding challenges) are returned.
///
/// Output is of the form: `(proof, variable_assignment, constant_poly_oracles, claimed_evals)`
///
/// # Panics
///
/// Panics if:
/// - No multivariate polynomials are provided.
/// - There aren't the same number of multivariate polynomials and claims.
/// - The degree of any multivariate polynomial exceeds [`MAX_DEGREE`] in any variable.
/// - The round polynomials are inconsistent with their corresponding claimed sum on `0` and `1`.
// TODO: Consider returning constant oracles as separate type.
pub fn prove_batch<O: MultivariatePolyOracle>(
mut claims: Vec<SecureField>,
mut multivariate_polys: Vec<O>,
lambda: SecureField,
channel: &mut impl Channel,
) -> (SumcheckProof, Vec<SecureField>, Vec<O>, Vec<SecureField>) {
let n_variables = multivariate_polys.iter().map(O::n_variables).max().unwrap();
assert_eq!(claims.len(), multivariate_polys.len());
let mut round_polys = Vec::new();
let mut assignment = Vec::new();
// Update the claims for the sum over `h`'s hypercube.
for (claim, multivariate_poly) in zip(&mut claims, &multivariate_polys) {
let n_unused_variables = n_variables - multivariate_poly.n_variables();
*claim *= BaseField::from(1 << n_unused_variables);
}
// Prove sum-check rounds
for round in 0..n_variables {
let n_remaining_rounds = n_variables - round;
let this_round_polys = zip(&multivariate_polys, &claims)
.enumerate()
.map(|(i, (multivariate_poly, &claim))| {
let round_poly = if n_remaining_rounds == multivariate_poly.n_variables() {
multivariate_poly.sum_as_poly_in_first_variable(claim)
} else {
(claim / BaseField::from(2)).into()
};
let eval_at_0 = round_poly.eval_at_point(SecureField::zero());
let eval_at_1 = round_poly.eval_at_point(SecureField::one());
assert_eq!(eval_at_0 + eval_at_1, claim, "i={i}, round={round}");
assert!(round_poly.degree() <= MAX_DEGREE, "i={i}, round={round}");
round_poly
})
.collect_vec();
let round_poly = random_linear_combination(&this_round_polys, lambda);
channel.mix_felts(&round_poly);
let challenge = channel.draw_felt();
claims = this_round_polys
.iter()
.map(|round_poly| round_poly.eval_at_point(challenge))
.collect();
multivariate_polys = multivariate_polys
.into_iter()
.map(|multivariate_poly| {
if n_remaining_rounds != multivariate_poly.n_variables() {
return multivariate_poly;
}
multivariate_poly.fix_first_variable(challenge)
})
.collect();
round_polys.push(round_poly);
assignment.push(challenge);
}
let proof = SumcheckProof { round_polys };
(proof, assignment, multivariate_polys, claims)
}
/// Returns `p_0 + alpha * p_1 + ... + alpha^(n-1) * p_{n-1}`.
fn random_linear_combination(
polys: &[UnivariatePoly<SecureField>],
alpha: SecureField,
) -> UnivariatePoly<SecureField> {
polys
.iter()
.rfold(Zero::zero(), |acc, poly| acc * alpha + poly.clone())
}
/// Partially verifies a sum-check proof.
///
/// Only "partial" since it does not fully verify the prover's claimed evaluation on the variable
/// assignment but checks if the sum of the round polynomials evaluated on `0` and `1` matches the
/// claim for each round. If the proof passes these checks, the variable assignment and the prover's
/// claimed evaluation are returned for the caller to validate otherwise an [`Err`] is returned.
///
/// Output is of the form `(variable_assignment, claimed_eval)`.
pub fn partially_verify(
mut claim: SecureField,
proof: &SumcheckProof,
channel: &mut impl Channel,
) -> Result<(Vec<SecureField>, SecureField), SumcheckError> {
let mut assignment = Vec::new();
for (round, round_poly) in proof.round_polys.iter().enumerate() {
if round_poly.degree() > MAX_DEGREE {
return Err(SumcheckError::DegreeInvalid { round });
}
// TODO: optimize this by sending one less coefficient, and computing it from the
// claim, instead of checking the claim. (Can also be done by quotienting).
let sum = round_poly.eval_at_point(Zero::zero()) + round_poly.eval_at_point(One::one());
if claim != sum {
return Err(SumcheckError::SumInvalid { claim, sum, round });
}
channel.mix_felts(round_poly);
let challenge = channel.draw_felt();
claim = round_poly.eval_at_point(challenge);
assignment.push(challenge);
}
Ok((assignment, claim))
}
#[derive(Debug, Clone)]
pub struct SumcheckProof {
pub round_polys: Vec<UnivariatePoly<SecureField>>,
}
/// Max degree of polynomials the verifier accepts in each round of the protocol.
pub const MAX_DEGREE: usize = 3;
/// Sum-check protocol verification error.
#[derive(Error, Debug)]
pub enum SumcheckError {
#[error("degree of the polynomial in round {round} is too high")]
DegreeInvalid { round: RoundIndex },
#[error("sum does not match the claim in round {round} (sum {sum}, claim {claim})")]
SumInvalid {
claim: SecureField,
sum: SecureField,
round: RoundIndex,
},
}
/// Sum-check round index where 0 corresponds to the first round.
pub type RoundIndex = usize;
#[cfg(test)]
mod tests {
use num_traits::One;
use crate::core::backend::CpuBackend;
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::fields::qm31::SecureField;
use crate::core::fields::Field;
use crate::core::lookups::mle::Mle;
use crate::core::lookups::sumcheck::{partially_verify, prove_batch};
#[test]
fn sumcheck_works() {
let values = test_channel().draw_felts(32);
let claim = values.iter().sum();
let mle = Mle::<CpuBackend, SecureField>::new(values);
let lambda = SecureField::one();
let (proof, ..) = prove_batch(vec![claim], vec![mle.clone()], lambda, &mut test_channel());
let (assignment, eval) = partially_verify(claim, &proof, &mut test_channel()).unwrap();
assert_eq!(eval, mle.eval_at_point(&assignment));
}
#[test]
fn batch_sumcheck_works() {
let mut channel = test_channel();
let values0 = channel.draw_felts(32);
let values1 = channel.draw_felts(32);
let claim0 = values0.iter().sum();
let claim1 = values1.iter().sum();
let mle0 = Mle::<CpuBackend, SecureField>::new(values0.clone());
let mle1 = Mle::<CpuBackend, SecureField>::new(values1.clone());
let lambda = channel.draw_felt();
let claims = vec![claim0, claim1];
let mles = vec![mle0.clone(), mle1.clone()];
let (proof, ..) = prove_batch(claims, mles, lambda, &mut test_channel());
let claim = claim0 + lambda * claim1;
let (assignment, eval) = partially_verify(claim, &proof, &mut test_channel()).unwrap();
let eval0 = mle0.eval_at_point(&assignment);
let eval1 = mle1.eval_at_point(&assignment);
assert_eq!(eval, eval0 + lambda * eval1);
}
#[test]
fn batch_sumcheck_with_different_n_variables() {
let mut channel = test_channel();
let values0 = channel.draw_felts(64);
let values1 = channel.draw_felts(32);
let claim0 = values0.iter().sum();
let claim1 = values1.iter().sum();
let mle0 = Mle::<CpuBackend, SecureField>::new(values0.clone());
let mle1 = Mle::<CpuBackend, SecureField>::new(values1.clone());
let lambda = channel.draw_felt();
let claims = vec![claim0, claim1];
let mles = vec![mle0.clone(), mle1.clone()];
let (proof, ..) = prove_batch(claims, mles, lambda, &mut test_channel());
let claim = claim0 + lambda * claim1.double();
let (assignment, eval) = partially_verify(claim, &proof, &mut test_channel()).unwrap();
let eval0 = mle0.eval_at_point(&assignment);
let eval1 = mle1.eval_at_point(&assignment[1..]);
assert_eq!(eval, eval0 + lambda * eval1);
}
#[test]
fn invalid_sumcheck_proof_fails() {
let values = test_channel().draw_felts(8);
let claim = values.iter().sum::<SecureField>();
let lambda = SecureField::one();
// Compromise the first value.
let mut invalid_values = values;
invalid_values[0] += SecureField::one();
let invalid_claim = vec![invalid_values.iter().sum::<SecureField>()];
let invalid_mle = vec![Mle::<CpuBackend, SecureField>::new(invalid_values.clone())];
let (invalid_proof, ..) =
prove_batch(invalid_claim, invalid_mle, lambda, &mut test_channel());
assert!(partially_verify(claim, &invalid_proof, &mut test_channel()).is_err());
}
fn test_channel() -> Blake2sChannel {
Blake2sChannel::default()
}
}

View File

@ -0,0 +1,356 @@
use std::iter::{zip, Sum};
use std::ops::{Add, Deref, Mul, Neg, Sub};
use num_traits::{One, Zero};
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{ExtensionOf, Field};
/// Univariate polynomial stored as coefficients in the monomial basis.
#[derive(Debug, Clone)]
pub struct UnivariatePoly<F: Field>(Vec<F>);
impl<F: Field> UnivariatePoly<F> {
pub fn new(coeffs: Vec<F>) -> Self {
let mut polynomial = Self(coeffs);
polynomial.truncate_leading_zeros();
polynomial
}
pub fn eval_at_point(&self, x: F) -> F {
horner_eval(&self.0, x)
}
// <https://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Polynomial_interpolation>
pub fn interpolate_lagrange(xs: &[F], ys: &[F]) -> Self {
assert_eq!(xs.len(), ys.len());
let mut coeffs = Self::zero();
for (i, (&xi, &yi)) in zip(xs, ys).enumerate() {
let mut prod = yi;
for (j, &xj) in xs.iter().enumerate() {
if i != j {
prod /= xi - xj;
}
}
let mut term = Self::new(vec![prod]);
for (j, &xj) in xs.iter().enumerate() {
if i != j {
term = term * (Self::x() - Self::new(vec![xj]));
}
}
coeffs = coeffs + term;
}
coeffs.truncate_leading_zeros();
coeffs
}
pub fn degree(&self) -> usize {
let mut coeffs = self.0.iter().rev();
let _ = (&mut coeffs).take_while(|v| v.is_zero());
coeffs.len().saturating_sub(1)
}
fn x() -> Self {
Self(vec![F::zero(), F::one()])
}
fn truncate_leading_zeros(&mut self) {
while self.0.last() == Some(&F::zero()) {
self.0.pop();
}
}
}
impl<F: Field> From<F> for UnivariatePoly<F> {
fn from(value: F) -> Self {
Self::new(vec![value])
}
}
impl<F: Field> Mul<F> for UnivariatePoly<F> {
type Output = Self;
fn mul(mut self, rhs: F) -> Self {
self.0.iter_mut().for_each(|coeff| *coeff *= rhs);
self
}
}
impl<F: Field> Mul for UnivariatePoly<F> {
type Output = Self;
fn mul(mut self, mut rhs: Self) -> Self {
if self.is_zero() || rhs.is_zero() {
return Self::zero();
}
self.truncate_leading_zeros();
rhs.truncate_leading_zeros();
let mut res = vec![F::zero(); self.0.len() + rhs.0.len() - 1];
for (i, coeff_a) in self.0.into_iter().enumerate() {
for (j, &coeff_b) in rhs.0.iter().enumerate() {
res[i + j] += coeff_a * coeff_b;
}
}
Self::new(res)
}
}
impl<F: Field> Add for UnivariatePoly<F> {
type Output = Self;
fn add(self, rhs: Self) -> Self {
let n = self.0.len().max(rhs.0.len());
let mut res = Vec::new();
for i in 0..n {
res.push(match (self.0.get(i), rhs.0.get(i)) {
(Some(&a), Some(&b)) => a + b,
(Some(&a), None) | (None, Some(&a)) => a,
_ => unreachable!(),
})
}
Self(res)
}
}
impl<F: Field> Sub for UnivariatePoly<F> {
type Output = Self;
fn sub(self, rhs: Self) -> Self {
self + (-rhs)
}
}
impl<F: Field> Neg for UnivariatePoly<F> {
type Output = Self;
fn neg(self) -> Self {
Self(self.0.into_iter().map(|v| -v).collect())
}
}
impl<F: Field> Zero for UnivariatePoly<F> {
fn zero() -> Self {
Self(vec![])
}
fn is_zero(&self) -> bool {
self.0.iter().all(F::is_zero)
}
}
impl<F: Field> Deref for UnivariatePoly<F> {
type Target = [F];
fn deref(&self) -> &[F] {
&self.0
}
}
/// Evaluates univariate polynomial using [Horner's method].
///
/// [Horner's method]: https://en.wikipedia.org/wiki/Horner%27s_method
pub fn horner_eval<F: Field>(coeffs: &[F], x: F) -> F {
coeffs
.iter()
.rfold(F::zero(), |acc, &coeff| acc * x + coeff)
}
/// Returns `v_0 + alpha * v_1 + ... + alpha^(n-1) * v_{n-1}`.
pub fn random_linear_combination(v: &[SecureField], alpha: SecureField) -> SecureField {
horner_eval(v, alpha)
}
/// Evaluates the lagrange kernel of the boolean hypercube.
///
/// The lagrange kernel of the boolean hypercube is a multilinear extension of the function that
/// when given `x, y` in `{0, 1}^n` evaluates to 1 if `x = y`, and evaluates to 0 otherwise.
pub fn eq<F: Field>(x: &[F], y: &[F]) -> F {
assert_eq!(x.len(), y.len());
zip(x, y)
.map(|(&xi, &yi)| xi * yi + (F::one() - xi) * (F::one() - yi))
.product()
}
/// Computes `eq(0, assignment) * eval0 + eq(1, assignment) * eval1`.
pub fn fold_mle_evals<F>(assignment: SecureField, eval0: F, eval1: F) -> SecureField
where
F: Field,
SecureField: ExtensionOf<F>,
{
assignment * (eval1 - eval0) + eval0
}
/// Projective fraction.
#[derive(Debug, Clone, Copy)]
pub struct Fraction<N, D> {
pub numerator: N,
pub denominator: D,
}
impl<N, D> Fraction<N, D> {
pub fn new(numerator: N, denominator: D) -> Self {
Self {
numerator,
denominator,
}
}
}
impl<N, D: Add<Output = D> + Add<N, Output = D> + Mul<N, Output = D> + Mul<Output = D> + Copy> Add
for Fraction<N, D>
{
type Output = Fraction<D, D>;
fn add(self, rhs: Self) -> Fraction<D, D> {
Fraction {
numerator: rhs.denominator * self.numerator + self.denominator * rhs.numerator,
denominator: self.denominator * rhs.denominator,
}
}
}
impl<N: Zero, D: One + Zero> Zero for Fraction<N, D>
where
Self: Add<Output = Self>,
{
fn zero() -> Self {
Self {
numerator: N::zero(),
denominator: D::one(),
}
}
fn is_zero(&self) -> bool {
self.numerator.is_zero() && !self.denominator.is_zero()
}
}
impl<N, D> Sum for Fraction<N, D>
where
Self: Zero,
{
fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
let first = iter.next().unwrap_or_else(Self::zero);
iter.fold(first, |a, b| a + b)
}
}
/// Represents the fraction `1 / x`
pub struct Reciprocal<T> {
x: T,
}
impl<T> Reciprocal<T> {
pub fn new(x: T) -> Self {
Self { x }
}
}
impl<T: Add<Output = T> + Mul<Output = T> + Copy> Add for Reciprocal<T> {
type Output = Fraction<T, T>;
fn add(self, rhs: Self) -> Fraction<T, T> {
// `1/a + 1/b = (a + b)/(a * b)`
Fraction {
numerator: self.x + rhs.x,
denominator: self.x * rhs.x,
}
}
}
#[cfg(test)]
mod tests {
use std::iter::zip;
use num_traits::{One, Zero};
use super::{horner_eval, UnivariatePoly};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::{eq, Fraction};
#[test]
fn lagrange_interpolation_works() {
let xs = [5, 1, 3, 9].map(BaseField::from);
let ys = [1, 2, 3, 4].map(BaseField::from);
let poly = UnivariatePoly::interpolate_lagrange(&xs, &ys);
for (x, y) in zip(xs, ys) {
assert_eq!(poly.eval_at_point(x), y, "mismatch for x={x}");
}
}
#[test]
fn horner_eval_works() {
let coeffs = [BaseField::from(9), BaseField::from(2), BaseField::from(3)];
let x = BaseField::from(7);
let eval = horner_eval(&coeffs, x);
assert_eq!(eval, coeffs[0] + coeffs[1] * x + coeffs[2] * x.square());
}
#[test]
fn eq_identical_hypercube_points_returns_one() {
let zero = SecureField::zero();
let one = SecureField::one();
let a = &[one, zero, one];
let eq_eval = eq(a, a);
assert_eq!(eq_eval, one);
}
#[test]
fn eq_different_hypercube_points_returns_zero() {
let zero = SecureField::zero();
let one = SecureField::one();
let a = &[one, zero, one];
let b = &[one, zero, zero];
let eq_eval = eq(a, b);
assert_eq!(eq_eval, zero);
}
#[test]
#[should_panic]
fn eq_different_size_points() {
let zero = SecureField::zero();
let one = SecureField::one();
eq(&[zero, one], &[zero]);
}
#[test]
fn fraction_addition_works() {
let a = Fraction::new(BaseField::from(1), BaseField::from(3));
let b = Fraction::new(BaseField::from(2), BaseField::from(6));
let Fraction {
numerator,
denominator,
} = a + b;
assert_eq!(
numerator / denominator,
BaseField::from(2) / BaseField::from(3)
);
}
}

View File

@ -0,0 +1,59 @@
use std::ops::{Deref, DerefMut};
pub mod air;
pub mod backend;
pub mod channel;
pub mod circle;
pub mod constraints;
pub mod fft;
pub mod fields;
pub mod fri;
pub mod lookups;
pub mod pcs;
pub mod poly;
pub mod proof_of_work;
pub mod prover;
pub mod queries;
#[cfg(test)]
pub mod test_utils;
pub mod utils;
pub mod vcs;
/// A vector in which each element relates (by index) to a column in the trace.
pub type ColumnVec<T> = Vec<T>;
/// A vector of [ColumnVec]s. Each [ColumnVec] relates (by index) to a component in the air.
#[derive(Debug, Clone)]
pub struct ComponentVec<T>(pub Vec<ColumnVec<T>>);
impl<T> ComponentVec<T> {
pub fn flatten(self) -> ColumnVec<T> {
self.0.into_iter().flatten().collect()
}
}
impl<T> ComponentVec<ColumnVec<T>> {
pub fn flatten_cols(self) -> Vec<T> {
self.0.into_iter().flatten().flatten().collect()
}
}
impl<T> Default for ComponentVec<T> {
fn default() -> Self {
Self(Vec::new())
}
}
impl<T> Deref for ComponentVec<T> {
type Target = Vec<ColumnVec<T>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for ComponentVec<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

View File

@ -0,0 +1,40 @@
//! Implements a FRI polynomial commitment scheme.
//! This is a protocol where the prover can commit on a set of polynomials and then prove their
//! opening on a set of points.
//! Note: This implementation is not really a polynomial commitment scheme, because we are not in
//! the unique decoding regime. This is enough for a STARK proof though, where we only want to imply
//! the existence of such polynomials, and are ok with having a small decoding list.
//! Note: Opened points cannot come from the commitment domain.
mod prover;
pub mod quotients;
mod utils;
mod verifier;
pub use self::prover::{
CommitmentSchemeProof, CommitmentSchemeProver, CommitmentTreeProver, TreeBuilder,
};
pub use self::utils::TreeVec;
pub use self::verifier::CommitmentSchemeVerifier;
use super::fri::FriConfig;
#[derive(Copy, Debug, Clone, PartialEq, Eq)]
pub struct TreeSubspan {
pub tree_index: usize,
pub col_start: usize,
pub col_end: usize,
}
#[derive(Debug, Clone, Copy)]
pub struct PcsConfig {
pub pow_bits: u32,
pub fri_config: FriConfig,
}
impl Default for PcsConfig {
fn default() -> Self {
Self {
pow_bits: 5,
fri_config: FriConfig::new(0, 1, 3),
}
}
}

View File

@ -0,0 +1,256 @@
use std::collections::BTreeMap;
use itertools::Itertools;
use tracing::{span, Level};
use super::super::circle::CirclePoint;
use super::super::fields::m31::BaseField;
use super::super::fields::qm31::SecureField;
use super::super::fri::{FriProof, FriProver};
use super::super::poly::circle::CanonicCoset;
use super::super::poly::BitReversedOrder;
use super::super::ColumnVec;
use super::quotients::{compute_fri_quotients, PointSample};
use super::utils::TreeVec;
use super::{PcsConfig, TreeSubspan};
use crate::core::air::Trace;
use crate::core::backend::BackendForChannel;
use crate::core::channel::{Channel, MerkleChannel};
use crate::core::poly::circle::{CircleEvaluation, CirclePoly};
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::vcs::ops::MerkleHasher;
use crate::core::vcs::prover::{MerkleDecommitment, MerkleProver};
/// The prover side of a FRI polynomial commitment scheme. See [super].
pub struct CommitmentSchemeProver<'a, B: BackendForChannel<MC>, MC: MerkleChannel> {
pub trees: TreeVec<CommitmentTreeProver<B, MC>>,
pub config: PcsConfig,
twiddles: &'a TwiddleTree<B>,
}
impl<'a, B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentSchemeProver<'a, B, MC> {
pub fn new(config: PcsConfig, twiddles: &'a TwiddleTree<B>) -> Self {
CommitmentSchemeProver {
trees: TreeVec::default(),
config,
twiddles,
}
}
fn commit(&mut self, polynomials: ColumnVec<CirclePoly<B>>, channel: &mut MC::C) {
let _span = span!(Level::INFO, "Commitment").entered();
let tree = CommitmentTreeProver::new(
polynomials,
self.config.fri_config.log_blowup_factor,
channel,
self.twiddles,
);
self.trees.push(tree);
}
pub fn tree_builder(&mut self) -> TreeBuilder<'_, 'a, B, MC> {
TreeBuilder {
tree_index: self.trees.len(),
commitment_scheme: self,
polys: Vec::default(),
}
}
pub fn roots(&self) -> TreeVec<<MC::H as MerkleHasher>::Hash> {
self.trees.as_ref().map(|tree| tree.commitment.root())
}
pub fn polynomials(&self) -> TreeVec<ColumnVec<&CirclePoly<B>>> {
self.trees
.as_ref()
.map(|tree| tree.polynomials.iter().collect())
}
pub fn evaluations(
&self,
) -> TreeVec<ColumnVec<&CircleEvaluation<B, BaseField, BitReversedOrder>>> {
self.trees
.as_ref()
.map(|tree| tree.evaluations.iter().collect())
}
pub fn trace(&self) -> Trace<'_, B> {
let polys = self.polynomials();
let evals = self.evaluations();
Trace { polys, evals }
}
pub fn prove_values(
&self,
sampled_points: TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>>,
channel: &mut MC::C,
) -> CommitmentSchemeProof<MC::H> {
// Evaluate polynomials on open points.
let span = span!(Level::INFO, "Evaluate columns out of domain").entered();
let samples = self
.polynomials()
.zip_cols(&sampled_points)
.map_cols(|(poly, points)| {
points
.iter()
.map(|&point| PointSample {
point,
value: poly.eval_at_point(point),
})
.collect_vec()
});
span.exit();
let sampled_values = samples
.as_cols_ref()
.map_cols(|x| x.iter().map(|o| o.value).collect());
channel.mix_felts(&sampled_values.clone().flatten_cols());
// Compute oods quotients for boundary constraints on the sampled points.
let columns = self.evaluations().flatten();
let quotients = compute_fri_quotients(
&columns,
&samples.flatten(),
channel.draw_felt(),
self.config.fri_config.log_blowup_factor,
);
// Run FRI commitment phase on the oods quotients.
let fri_prover =
FriProver::<B, MC>::commit(channel, self.config.fri_config, &quotients, self.twiddles);
// Proof of work.
let span1 = span!(Level::INFO, "Grind").entered();
let proof_of_work = B::grind(channel, self.config.pow_bits);
span1.exit();
channel.mix_nonce(proof_of_work);
// FRI decommitment phase.
let (fri_proof, fri_query_domains) = fri_prover.decommit(channel);
// Decommit the FRI queries on the merkle trees.
let decommitment_results = self.trees.as_ref().map(|tree| {
let queries = fri_query_domains
.iter()
.map(|(&log_size, domain)| (log_size, domain.flatten()))
.collect();
tree.decommit(queries)
});
let queried_values = decommitment_results.as_ref().map(|(v, _)| v.clone());
let decommitments = decommitment_results.map(|(_, d)| d);
CommitmentSchemeProof {
sampled_values,
decommitments,
queried_values,
proof_of_work,
fri_proof,
}
}
}
#[derive(Debug)]
pub struct CommitmentSchemeProof<H: MerkleHasher> {
pub sampled_values: TreeVec<ColumnVec<Vec<SecureField>>>,
pub decommitments: TreeVec<MerkleDecommitment<H>>,
pub queried_values: TreeVec<ColumnVec<Vec<BaseField>>>,
pub proof_of_work: u64,
pub fri_proof: FriProof<H>,
}
pub struct TreeBuilder<'a, 'b, B: BackendForChannel<MC>, MC: MerkleChannel> {
tree_index: usize,
commitment_scheme: &'a mut CommitmentSchemeProver<'b, B, MC>,
polys: ColumnVec<CirclePoly<B>>,
}
impl<'a, 'b, B: BackendForChannel<MC>, MC: MerkleChannel> TreeBuilder<'a, 'b, B, MC> {
pub fn extend_evals(
&mut self,
columns: ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
) -> TreeSubspan {
let span = span!(Level::INFO, "Interpolation for commitment").entered();
let col_start = self.polys.len();
let polys = columns
.into_iter()
.map(|eval| eval.interpolate_with_twiddles(self.commitment_scheme.twiddles))
.collect_vec();
span.exit();
self.polys.extend(polys);
TreeSubspan {
tree_index: self.tree_index,
col_start,
col_end: self.polys.len(),
}
}
pub fn extend_polys(&mut self, polys: ColumnVec<CirclePoly<B>>) -> TreeSubspan {
let col_start = self.polys.len();
self.polys.extend(polys);
TreeSubspan {
tree_index: self.tree_index,
col_start,
col_end: self.polys.len(),
}
}
pub fn commit(self, channel: &mut MC::C) {
let _span = span!(Level::INFO, "Commitment").entered();
self.commitment_scheme.commit(self.polys, channel);
}
}
/// Prover data for a single commitment tree in a commitment scheme. The commitment scheme allows to
/// commit on a set of polynomials at a time. This corresponds to such a set.
pub struct CommitmentTreeProver<B: BackendForChannel<MC>, MC: MerkleChannel> {
pub polynomials: ColumnVec<CirclePoly<B>>,
pub evaluations: ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
pub commitment: MerkleProver<B, MC::H>,
}
impl<B: BackendForChannel<MC>, MC: MerkleChannel> CommitmentTreeProver<B, MC> {
pub fn new(
polynomials: ColumnVec<CirclePoly<B>>,
log_blowup_factor: u32,
channel: &mut MC::C,
twiddles: &TwiddleTree<B>,
) -> Self {
let span = span!(Level::INFO, "Extension").entered();
let evaluations = polynomials
.iter()
.map(|poly| {
poly.evaluate_with_twiddles(
CanonicCoset::new(poly.log_size() + log_blowup_factor).circle_domain(),
twiddles,
)
})
.collect_vec();
span.exit();
let _span = span!(Level::INFO, "Merkle").entered();
let tree = MerkleProver::commit(evaluations.iter().map(|eval| &eval.values).collect());
MC::mix_root(channel, tree.root());
CommitmentTreeProver {
polynomials,
evaluations,
commitment: tree,
}
}
/// Decommits the merkle tree on the given query positions.
/// Returns the values at the queried positions and the decommitment.
/// The queries are given as a mapping from the log size of the layer size to the queried
/// positions on each column of that size.
fn decommit(
&self,
queries: BTreeMap<u32, Vec<usize>>,
) -> (ColumnVec<Vec<BaseField>>, MerkleDecommitment<MC::H>) {
let eval_vec = self
.evaluations
.iter()
.map(|eval| &eval.values)
.collect_vec();
self.commitment.decommit(queries, eval_vec)
}
}

View File

@ -0,0 +1,218 @@
use std::cmp::Reverse;
use std::collections::BTreeMap;
use std::iter::zip;
use itertools::{izip, multiunzip, Itertools};
use tracing::{span, Level};
use crate::core::backend::cpu::quotients::{accumulate_row_quotients, quotient_constants};
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fri::SparseCircleEvaluation;
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, PolyOps, SecureEvaluation,
};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::VerificationError;
use crate::core::queries::SparseSubCircleDomain;
use crate::core::utils::bit_reverse_index;
pub trait QuotientOps: PolyOps {
/// Accumulates the quotients of the columns at the given domain.
/// For a column f(x), and a point sample (p,v), the quotient is
/// (f(x) - V0(x))/V1(x)
/// where V0(p)=v, V0(conj(p))=conj(v), and V1 is a vanishing polynomial for p,conj(p).
/// This ensures that if f(p)=v, then the quotient is a polynomial.
/// The result is a linear combination of the quotients using powers of random_coeff.
fn accumulate_quotients(
domain: CircleDomain,
columns: &[&CircleEvaluation<Self, BaseField, BitReversedOrder>],
random_coeff: SecureField,
sample_batches: &[ColumnSampleBatch],
log_blowup_factor: u32,
) -> SecureEvaluation<Self, BitReversedOrder>;
}
/// A batch of column samplings at a point.
pub struct ColumnSampleBatch {
/// The point at which the columns are sampled.
pub point: CirclePoint<SecureField>,
/// The sampled column indices and their values at the point.
pub columns_and_values: Vec<(usize, SecureField)>,
}
impl ColumnSampleBatch {
/// Groups column samples by sampled point.
/// # Arguments
/// samples: For each column, a vector of samples.
pub fn new_vec(samples: &[&Vec<PointSample>]) -> Vec<Self> {
// Group samples by point, and create a ColumnSampleBatch for each point.
// This should keep a stable ordering.
let mut grouped_samples = BTreeMap::new();
for (column_index, samples) in samples.iter().enumerate() {
for sample in samples.iter() {
grouped_samples
.entry(sample.point)
.or_insert_with(Vec::new)
.push((column_index, sample.value));
}
}
grouped_samples
.into_iter()
.map(|(point, columns_and_values)| ColumnSampleBatch {
point,
columns_and_values,
})
.collect()
}
}
pub struct PointSample {
pub point: CirclePoint<SecureField>,
pub value: SecureField,
}
pub fn compute_fri_quotients<B: QuotientOps>(
columns: &[&CircleEvaluation<B, BaseField, BitReversedOrder>],
samples: &[Vec<PointSample>],
random_coeff: SecureField,
log_blowup_factor: u32,
) -> Vec<SecureEvaluation<B, BitReversedOrder>> {
let _span = span!(Level::INFO, "Compute FRI quotients").entered();
zip(columns, samples)
.sorted_by_key(|(c, _)| Reverse(c.domain.log_size()))
.group_by(|(c, _)| c.domain.log_size())
.into_iter()
.map(|(log_size, tuples)| {
let (columns, samples): (Vec<_>, Vec<_>) = tuples.unzip();
let domain = CanonicCoset::new(log_size).circle_domain();
// TODO: slice.
let sample_batches = ColumnSampleBatch::new_vec(&samples);
B::accumulate_quotients(
domain,
&columns,
random_coeff,
&sample_batches,
log_blowup_factor,
)
})
.collect()
}
pub fn fri_answers(
column_log_sizes: Vec<u32>,
samples: &[Vec<PointSample>],
random_coeff: SecureField,
query_domain_per_log_size: BTreeMap<u32, SparseSubCircleDomain>,
queried_values_per_column: &[Vec<BaseField>],
) -> Result<Vec<SparseCircleEvaluation>, VerificationError> {
izip!(column_log_sizes, samples, queried_values_per_column)
.sorted_by_key(|(log_size, ..)| Reverse(*log_size))
.group_by(|(log_size, ..)| *log_size)
.into_iter()
.map(|(log_size, tuples)| {
let (_, samples, queried_valued_per_column): (Vec<_>, Vec<_>, Vec<_>) =
multiunzip(tuples);
fri_answers_for_log_size(
log_size,
&samples, // Here it's the points and sampled values (in the proof)
random_coeff,
&query_domain_per_log_size[&log_size],
&queried_valued_per_column, // Here it's queried values (in the proof)
)
})
.collect()
}
pub fn fri_answers_for_log_size(
log_size: u32,
samples: &[&Vec<PointSample>],
random_coeff: SecureField,
query_domain: &SparseSubCircleDomain,
queried_values_per_column: &[&Vec<BaseField>],
) -> Result<SparseCircleEvaluation, VerificationError> {
let commitment_domain = CanonicCoset::new(log_size).circle_domain();
let sample_batches = ColumnSampleBatch::new_vec(samples);
for queried_values in queried_values_per_column {
if queried_values.len() != query_domain.flatten().len() {
return Err(VerificationError::InvalidStructure(
"Insufficient number of queried values".to_string(),
));
}
}
let mut queried_values_per_column = queried_values_per_column
.iter()
.map(|q| q.iter())
.collect_vec();
let mut evals = Vec::new();
for subdomain in query_domain.iter() {
let domain = subdomain.to_circle_domain(&commitment_domain);
let quotient_constants = quotient_constants(&sample_batches, random_coeff, domain);
let mut column_evals = Vec::new();
for queried_values in queried_values_per_column.iter_mut() {
let eval = CircleEvaluation::new(
domain,
queried_values.take(domain.size()).copied().collect_vec(),
);
column_evals.push(eval);
}
let mut values = Vec::new();
for row in 0..domain.size() {
let domain_point = domain.at(bit_reverse_index(row, log_size));
let value = accumulate_row_quotients(
&sample_batches,
&column_evals.iter().collect_vec(),
&quotient_constants,
row,
domain_point,
);
values.push(value);
}
let eval = CircleEvaluation::new(domain, values);
evals.push(eval);
}
let res = SparseCircleEvaluation::new(evals);
if !queried_values_per_column.iter().all(|x| x.is_empty()) {
return Err(VerificationError::InvalidStructure(
"Too many queried values".to_string(),
));
}
Ok(res)
}
#[cfg(test)]
mod tests {
use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly};
use crate::core::circle::SECURE_FIELD_CIRCLE_GEN;
use crate::core::pcs::quotients::{compute_fri_quotients, PointSample};
use crate::core::poly::circle::CanonicCoset;
use crate::{m31, qm31};
#[test]
fn test_quotients_are_low_degree() {
const LOG_SIZE: u32 = 7;
const LOG_BLOWUP_FACTOR: u32 = 1;
let polynomial = CpuCirclePoly::new((0..1 << LOG_SIZE).map(|i| m31!(i)).collect());
let eval_domain = CanonicCoset::new(LOG_SIZE + 1).circle_domain();
let eval = polynomial.evaluate(eval_domain);
let point = SECURE_FIELD_CIRCLE_GEN;
let value = polynomial.eval_at_point(point);
let coeff = qm31!(1, 2, 3, 4);
let quot_eval = compute_fri_quotients(
&[&eval],
&[vec![PointSample { point, value }]],
coeff,
LOG_BLOWUP_FACTOR,
)
.pop()
.unwrap();
let quot_poly_base_field =
CpuCircleEvaluation::new(eval_domain, quot_eval.values.columns[0].clone())
.interpolate();
assert!(quot_poly_base_field.is_in_fri_space(LOG_SIZE));
}
}

View File

@ -0,0 +1,158 @@
use std::collections::BTreeSet;
use std::ops::{Deref, DerefMut};
use itertools::zip_eq;
use serde::{Deserialize, Serialize};
use super::TreeSubspan;
use crate::core::ColumnVec;
/// A container that holds an element for each commitment tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeVec<T>(pub Vec<T>);
impl<T> TreeVec<T> {
pub fn new(vec: Vec<T>) -> TreeVec<T> {
TreeVec(vec)
}
pub fn map<U, F: Fn(T) -> U>(self, f: F) -> TreeVec<U> {
TreeVec(self.0.into_iter().map(f).collect())
}
pub fn zip<U>(self, other: impl Into<TreeVec<U>>) -> TreeVec<(T, U)> {
let other = other.into();
TreeVec(self.0.into_iter().zip(other.0).collect())
}
pub fn zip_eq<U>(self, other: impl Into<TreeVec<U>>) -> TreeVec<(T, U)> {
let other = other.into();
TreeVec(zip_eq(self.0, other.0).collect())
}
pub fn as_ref(&self) -> TreeVec<&T> {
TreeVec(self.iter().collect())
}
pub fn as_mut(&mut self) -> TreeVec<&mut T> {
TreeVec(self.iter_mut().collect())
}
}
/// Converts `&TreeVec<T>` to `TreeVec<&T>`.
impl<'a, T> From<&'a TreeVec<T>> for TreeVec<&'a T> {
fn from(val: &'a TreeVec<T>) -> Self {
val.as_ref()
}
}
impl<T> Deref for TreeVec<T> {
type Target = Vec<T>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for TreeVec<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<T> Default for TreeVec<T> {
fn default() -> Self {
TreeVec(Vec::new())
}
}
impl<T> TreeVec<ColumnVec<T>> {
pub fn map_cols<U, F: FnMut(T) -> U>(self, mut f: F) -> TreeVec<ColumnVec<U>> {
TreeVec(
self.0
.into_iter()
.map(|column| column.into_iter().map(&mut f).collect())
.collect(),
)
}
/// Zips two [`TreeVec<ColumVec<T>>`] with the same structure (number of columns in each tree).
/// The resulting [`TreeVec<ColumVec<T>>`] has the same structure, with each value being a tuple
/// of the corresponding values from the input [`TreeVec<ColumVec<T>>`].
pub fn zip_cols<U>(
self,
other: impl Into<TreeVec<ColumnVec<U>>>,
) -> TreeVec<ColumnVec<(T, U)>> {
let other = other.into();
TreeVec(
zip_eq(self.0, other.0)
.map(|(column1, column2)| zip_eq(column1, column2).collect())
.collect(),
)
}
pub fn as_cols_ref(&self) -> TreeVec<ColumnVec<&T>> {
TreeVec(self.iter().map(|column| column.iter().collect()).collect())
}
/// Flattens the [`TreeVec<ColumVec<T>>`] into a single [`ColumnVec`] with all the columns
/// combined.
pub fn flatten(self) -> ColumnVec<T> {
self.0.into_iter().flatten().collect()
}
/// Appends the columns of another [`TreeVec<ColumVec<T>>`] to this one.
pub fn append_cols(&mut self, mut other: TreeVec<ColumnVec<T>>) {
let n_trees = self.0.len().max(other.0.len());
self.0.resize_with(n_trees, Default::default);
for (self_col, other_col) in self.0.iter_mut().zip(other.0.iter_mut()) {
self_col.append(other_col);
}
}
/// Concatenates the columns of multiple [`TreeVec<ColumVec<T>>`] into a single
/// [`TreeVec<ColumVec<T>>`].
pub fn concat_cols(
trees: impl Iterator<Item = TreeVec<ColumnVec<T>>>,
) -> TreeVec<ColumnVec<T>> {
let mut result = TreeVec::default();
for tree in trees {
result.append_cols(tree);
}
result
}
/// Extracts a sub-tree based on the specified locations.
///
/// # Panics
///
/// If two or more locations have the same tree index.
pub fn sub_tree(&self, locations: &[TreeSubspan]) -> TreeVec<ColumnVec<&T>> {
let tree_indicies: BTreeSet<usize> = locations.iter().map(|l| l.tree_index).collect();
assert_eq!(tree_indicies.len(), locations.len());
let max_tree_index = tree_indicies.iter().max().unwrap_or(&0);
let mut res = TreeVec(vec![Vec::new(); max_tree_index + 1]);
for &location in locations {
// TODO(andrew): Throwing error here might be better instead.
let chunk = self.get_chunk(location).unwrap();
res[location.tree_index] = chunk;
}
res
}
fn get_chunk(&self, location: TreeSubspan) -> Option<ColumnVec<&T>> {
let tree = self.0.get(location.tree_index)?;
let chunk = tree.get(location.col_start..location.col_end)?;
Some(chunk.iter().collect())
}
}
impl<'a, T> From<&'a TreeVec<ColumnVec<T>>> for TreeVec<ColumnVec<&'a T>> {
fn from(val: &'a TreeVec<ColumnVec<T>>) -> Self {
val.as_cols_ref()
}
}
impl<T> TreeVec<ColumnVec<Vec<T>>> {
/// Flattens a [`TreeVec<ColumVec<T>>`] of [Vec]s into a single [Vec] with all the elements
/// combined.
pub fn flatten_cols(self) -> Vec<T> {
self.0.into_iter().flatten().flatten().collect()
}
}

View File

@ -0,0 +1,134 @@
use std::iter::zip;
use itertools::Itertools;
use super::super::circle::CirclePoint;
use super::super::fields::qm31::SecureField;
use super::super::fri::{CirclePolyDegreeBound, FriVerifier};
use super::quotients::{fri_answers, PointSample};
use super::utils::TreeVec;
use super::{CommitmentSchemeProof, PcsConfig};
use crate::core::channel::{Channel, MerkleChannel};
use crate::core::prover::VerificationError;
use crate::core::vcs::ops::MerkleHasher;
use crate::core::vcs::verifier::MerkleVerifier;
use crate::core::ColumnVec;
/// The verifier side of a FRI polynomial commitment scheme. See [super].
#[derive(Default)]
pub struct CommitmentSchemeVerifier<MC: MerkleChannel> {
pub trees: TreeVec<MerkleVerifier<MC::H>>,
pub config: PcsConfig,
}
impl<MC: MerkleChannel> CommitmentSchemeVerifier<MC> {
pub fn new(config: PcsConfig) -> Self {
Self {
trees: TreeVec::default(),
config,
}
}
/// A [TreeVec<ColumnVec>] of the log sizes of each column in each commitment tree.
fn column_log_sizes(&self) -> TreeVec<ColumnVec<u32>> {
self.trees
.as_ref()
.map(|tree| tree.column_log_sizes.clone())
}
/// Reads a commitment from the prover.
pub fn commit(
&mut self,
commitment: <MC::H as MerkleHasher>::Hash,
log_sizes: &[u32],
channel: &mut MC::C,
) {
MC::mix_root(channel, commitment);
let extended_log_sizes = log_sizes
.iter()
.map(|&log_size| log_size + self.config.fri_config.log_blowup_factor)
.collect();
let verifier = MerkleVerifier::new(commitment, extended_log_sizes);
self.trees.push(verifier);
}
pub fn verify_values(
&self,
sampled_points: TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>>,
proof: CommitmentSchemeProof<MC::H>,
channel: &mut MC::C,
) -> Result<(), VerificationError> {
channel.mix_felts(&proof.sampled_values.clone().flatten_cols());
let random_coeff = channel.draw_felt();
let bounds = self
.column_log_sizes()
.zip_cols(&sampled_points)
.map_cols(|(log_size, sampled_points)| {
vec![
CirclePolyDegreeBound::new(log_size - self.config.fri_config.log_blowup_factor);
sampled_points.len()
]
})
.flatten_cols()
.into_iter()
.sorted()
.rev()
.dedup()
.collect_vec();
// FRI commitment phase on OODS quotients.
let mut fri_verifier =
FriVerifier::<MC>::commit(channel, self.config.fri_config, proof.fri_proof, bounds)?;
// Verify proof of work.
channel.mix_nonce(proof.proof_of_work);
if channel.trailing_zeros() < self.config.pow_bits {
return Err(VerificationError::ProofOfWork);
}
// Get FRI query domains.
let fri_query_domains = fri_verifier.column_query_positions(channel);
// Verify merkle decommitments.
self.trees
.as_ref()
.zip_eq(proof.decommitments)
.zip_eq(proof.queried_values.clone())
.map(|((tree, decommitment), queried_values)| {
let queries = fri_query_domains
.iter()
.map(|(&log_size, domain)| (log_size, domain.flatten()))
.collect();
tree.verify(queries, queried_values, decommitment)
})
.0
.into_iter()
.collect::<Result<_, _>>()?;
println!("DONE");
// Answer FRI queries.
let samples = sampled_points // Sample point is always the same and the value is the one provided in the proof
.zip_cols(proof.sampled_values)
.map_cols(|(sampled_points, sampled_values)| {
zip(sampled_points, sampled_values)
.map(|(point, value)| PointSample { point, value })
.collect_vec()
})
.flatten();
// TODO(spapini): Properly defined column log size and dinstinguish between poly and
// commitment.
let fri_answers = fri_answers(
self.column_log_sizes().flatten().into_iter().collect(),
&samples,
random_coeff,
fri_query_domains,
&proof.queried_values.flatten(),
)?;
fri_verifier.decommit(fri_answers)?;
Ok(())
}
}

View File

@ -0,0 +1,77 @@
use super::CircleDomain;
use crate::core::circle::{CirclePoint, CirclePointIndex, Coset};
use crate::core::fields::m31::BaseField;
/// A coset of the form G_{2n} + <G_n>, where G_n is the generator of the
/// subgroup of order n. The ordering on this coset is G_2n + i * G_n.
/// These cosets can be used as a [CircleDomain], and be interpolated on.
/// Note that this changes the ordering on the coset to be like [CircleDomain],
/// which is G_2n + i * G_n/2 and then -G_2n -i * G_n/2.
/// For example, the Xs below are a canonic coset with n=8.
/// ```text
/// X O X
/// O O
/// X X
/// O O
/// X X
/// O O
/// X O X
/// ```
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct CanonicCoset {
pub coset: Coset,
}
impl CanonicCoset {
pub fn new(log_size: u32) -> Self {
assert!(log_size > 0);
Self {
coset: Coset::odds(log_size),
}
}
/// Gets the full coset represented G_{2n} + <G_n>.
pub fn coset(&self) -> Coset {
self.coset
}
/// Gets half of the coset (its conjugate complements to the whole coset), G_{2n} + <G_{n/2}>
pub fn half_coset(&self) -> Coset {
Coset::half_odds(self.log_size() - 1)
}
/// Gets the [CircleDomain] representing the same point set (in another order).
pub fn circle_domain(&self) -> CircleDomain {
CircleDomain::new(self.half_coset())
}
/// Returns the log size of the coset.
pub fn log_size(&self) -> u32 {
self.coset.log_size
}
/// Returns the size of the coset.
pub fn size(&self) -> usize {
self.coset.size()
}
pub fn initial_index(&self) -> CirclePointIndex {
self.coset.initial_index
}
pub fn step_size(&self) -> CirclePointIndex {
self.coset.step_size
}
pub fn step(&self) -> CirclePoint<BaseField> {
self.coset.step
}
pub fn index_at(&self, index: usize) -> CirclePointIndex {
self.coset.index_at(index)
}
pub fn at(&self, i: usize) -> CirclePoint<BaseField> {
self.coset.at(i)
}
}

View File

@ -0,0 +1,188 @@
use std::iter::Chain;
use itertools::Itertools;
use crate::core::circle::{
CirclePoint, CirclePointIndex, Coset, CosetIterator, M31_CIRCLE_LOG_ORDER,
};
use crate::core::fields::m31::BaseField;
pub const MAX_CIRCLE_DOMAIN_LOG_SIZE: u32 = M31_CIRCLE_LOG_ORDER - 1;
/// A valid domain for circle polynomial interpolation and evaluation.
/// Valid domains are a disjoint union of two conjugate cosets: +-C + <G_n>.
/// The ordering defined on this domain is C + iG_n, and then -C - iG_n.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct CircleDomain {
pub half_coset: Coset,
}
impl CircleDomain {
/// Given a coset C + <G_n>, constructs the circle domain +-C + <G_n> (i.e.,
/// this coset and its conjugate).
pub fn new(half_coset: Coset) -> Self {
Self { half_coset }
}
pub fn iter(&self) -> CircleDomainIterator {
self.half_coset
.iter()
.chain(self.half_coset.conjugate().iter())
}
/// Iterates over point indices.
pub fn iter_indices(&self) -> CircleDomainIndexIterator {
self.half_coset
.iter_indices()
.chain(self.half_coset.conjugate().iter_indices())
}
/// Returns the size of the domain.
pub fn size(&self) -> usize {
1 << self.log_size()
}
/// Returns the log size of the domain.
pub fn log_size(&self) -> u32 {
self.half_coset.log_size + 1
}
/// Returns the `i` th domain element.
pub fn at(&self, i: usize) -> CirclePoint<BaseField> {
self.index_at(i).to_point()
}
/// Returns the [CirclePointIndex] of the `i`th domain element.
pub fn index_at(&self, i: usize) -> CirclePointIndex {
if i < self.half_coset.size() {
self.half_coset.index_at(i)
} else {
-self.half_coset.index_at(i - self.half_coset.size())
}
}
pub fn find(&self, i: CirclePointIndex) -> Option<usize> {
if let Some(d) = self.half_coset.find(i) {
return Some(d);
}
if let Some(d) = self.half_coset.conjugate().find(i) {
return Some(self.half_coset.size() + d);
}
None
}
/// Returns true if the domain is canonic.
///
/// Canonic domains are domains with elements that are the entire set of points defined by
/// `G_2n + <G_n>` where `G_n` and `G_2n` are obtained by repeatedly doubling
/// [crate::core::circle::M31_CIRCLE_GEN].
pub fn is_canonic(&self) -> bool {
self.half_coset.initial_index * 4 == self.half_coset.step_size
}
/// Splits a circle domain into a smaller [CircleDomain]s, shifted by offsets.
pub fn split(&self, log_parts: u32) -> (CircleDomain, Vec<CirclePointIndex>) {
assert!(log_parts <= self.half_coset.log_size);
let subdomain = CircleDomain::new(Coset::new(
self.half_coset.initial_index,
self.half_coset.log_size - log_parts,
));
let shifts = (0..1 << log_parts)
.map(|i| self.half_coset.step_size * i)
.collect_vec();
(subdomain, shifts)
}
pub fn shift(&self, shift: CirclePointIndex) -> CircleDomain {
CircleDomain::new(self.half_coset.shift(shift))
}
}
impl IntoIterator for CircleDomain {
type Item = CirclePoint<BaseField>;
type IntoIter = CircleDomainIterator;
/// Iterates over the points in the domain.
fn into_iter(self) -> CircleDomainIterator {
self.iter()
}
}
/// An iterator over points in a circle domain.
///
/// Let the domain be `+-c + <G>`. The first iterated points are `c + <G>`, then `-c + <-G>`.
pub type CircleDomainIterator =
Chain<CosetIterator<CirclePoint<BaseField>>, CosetIterator<CirclePoint<BaseField>>>;
/// Like [CircleDomainIterator] but returns corresponding [CirclePointIndex]s.
type CircleDomainIndexIterator =
Chain<CosetIterator<CirclePointIndex>, CosetIterator<CirclePointIndex>>;
#[cfg(test)]
mod tests {
use itertools::Itertools;
use super::CircleDomain;
use crate::core::circle::{CirclePointIndex, Coset};
use crate::core::poly::circle::CanonicCoset;
#[test]
fn test_circle_domain_iterator() {
let domain = CircleDomain::new(Coset::new(CirclePointIndex::generator(), 2));
for (i, point) in domain.iter().enumerate() {
if i < 4 {
assert_eq!(
point,
(CirclePointIndex::generator() + CirclePointIndex::subgroup_gen(2) * i)
.to_point()
);
} else {
assert_eq!(
point,
(-(CirclePointIndex::generator() + CirclePointIndex::subgroup_gen(2) * i))
.to_point()
);
}
}
}
#[test]
fn is_canonic_invalid_domain() {
let half_coset = Coset::new(CirclePointIndex::generator(), 4);
let not_canonic_domain = CircleDomain::new(half_coset);
assert!(!not_canonic_domain.is_canonic());
}
#[test]
fn test_at_circle_domain() {
let domain = CanonicCoset::new(7).circle_domain();
let half_domain_size = domain.size() / 2;
for i in 0..half_domain_size {
assert_eq!(domain.index_at(i), -domain.index_at(i + half_domain_size));
assert_eq!(domain.at(i), domain.at(i + half_domain_size).conjugate());
}
}
#[test]
fn test_domain_split() {
let domain = CanonicCoset::new(5).circle_domain();
let (subdomain, shifts) = domain.split(2);
let domain_points = domain.iter().collect::<Vec<_>>();
let points_for_each_domain = shifts
.iter()
.map(|&shift| (subdomain.shift(shift)).iter().collect_vec())
.collect::<Vec<_>>();
// Interleave the points from each subdomain.
let extended_points = (0..(1 << 3))
.flat_map(|point_ind| {
(0..(1 << 2))
.map(|shift_ind| points_for_each_domain[shift_ind][point_ind])
.collect_vec()
})
.collect_vec();
assert_eq!(domain_points, extended_points);
}
}

View File

@ -0,0 +1,218 @@
use std::marker::PhantomData;
use std::ops::{Deref, Index};
use educe::Educe;
use super::{CanonicCoset, CircleDomain, CirclePoly, PolyOps};
use crate::core::backend::cpu::CpuCircleEvaluation;
use crate::core::backend::{Col, Column};
use crate::core::circle::{CirclePointIndex, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::fields::{ExtensionOf, FieldOps};
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::{BitReversedOrder, NaturalOrder};
use crate::core::utils::bit_reverse_index;
/// An evaluation defined on a [CircleDomain].
/// The values are ordered according to the [CircleDomain] ordering.
#[derive(Educe)]
#[educe(Clone, Debug)]
pub struct CircleEvaluation<B: FieldOps<F>, F: ExtensionOf<BaseField>, EvalOrder = NaturalOrder> {
pub domain: CircleDomain,
pub values: Col<B, F>,
_eval_order: PhantomData<EvalOrder>,
}
impl<B: FieldOps<F>, F: ExtensionOf<BaseField>, EvalOrder> CircleEvaluation<B, F, EvalOrder> {
pub fn new(domain: CircleDomain, values: Col<B, F>) -> Self {
assert_eq!(domain.size(), values.len());
Self {
domain,
values,
_eval_order: PhantomData,
}
}
}
// Note: The concrete implementation of the poly operations is in the specific backend used.
// For example, the CPU backend implementation is in `src/core/backend/cpu/poly.rs`.
impl<F: ExtensionOf<BaseField>, B: FieldOps<F>> CircleEvaluation<B, F, NaturalOrder> {
// TODO(spapini): Remove. Is this even used.
pub fn get_at(&self, point_index: CirclePointIndex) -> F {
self.values
.at(self.domain.find(point_index).expect("Not in domain"))
}
pub fn bit_reverse(mut self) -> CircleEvaluation<B, F, BitReversedOrder> {
B::bit_reverse_column(&mut self.values);
CircleEvaluation::new(self.domain, self.values)
}
}
impl<F: ExtensionOf<BaseField>> CpuCircleEvaluation<F, NaturalOrder> {
pub fn fetch_eval_on_coset(&self, coset: Coset) -> CosetSubEvaluation<'_, F> {
assert!(coset.log_size() <= self.domain.half_coset.log_size());
if let Some(offset) = self.domain.half_coset.find(coset.initial_index) {
return CosetSubEvaluation::new(
&self.values[..self.domain.half_coset.size()],
offset,
coset.step_size / self.domain.half_coset.step_size,
);
}
if let Some(offset) = self.domain.half_coset.conjugate().find(coset.initial_index) {
return CosetSubEvaluation::new(
&self.values[self.domain.half_coset.size()..],
offset,
(-coset.step_size) / self.domain.half_coset.step_size,
);
}
panic!("Coset not found in domain");
}
}
impl<B: PolyOps> CircleEvaluation<B, BaseField, BitReversedOrder> {
/// Creates a [CircleEvaluation] from values ordered according to
/// [CanonicCoset]. For example, the canonic coset might look like this:
/// G_8, G_8 + G_4, G_8 + 2G_4, G_8 + 3G_4.
/// The circle domain will be ordered like this:
/// G_8, G_8 + 2G_4, -G_8, -G_8 - 2G_4.
pub fn new_canonical_ordered(coset: CanonicCoset, values: Col<B, BaseField>) -> Self {
B::new_canonical_ordered(coset, values)
}
/// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation.
pub fn interpolate(self) -> CirclePoly<B> {
let coset = self.domain.half_coset;
B::interpolate(self, &B::precompute_twiddles(coset))
}
/// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation, using
/// precomputed twiddles.
pub fn interpolate_with_twiddles(self, twiddles: &TwiddleTree<B>) -> CirclePoly<B> {
B::interpolate(self, twiddles)
}
}
impl<B: FieldOps<F>, F: ExtensionOf<BaseField>> CircleEvaluation<B, F, BitReversedOrder> {
pub fn bit_reverse(mut self) -> CircleEvaluation<B, F, NaturalOrder> {
B::bit_reverse_column(&mut self.values);
CircleEvaluation::new(self.domain, self.values)
}
pub fn get_at(&self, point_index: CirclePointIndex) -> F {
self.values.at(bit_reverse_index(
self.domain.find(point_index).expect("Not in domain"),
self.domain.log_size(),
))
}
}
impl<B: FieldOps<F>, F: ExtensionOf<BaseField>, EvalOrder> Deref
for CircleEvaluation<B, F, EvalOrder>
{
type Target = Col<B, F>;
fn deref(&self) -> &Self::Target {
&self.values
}
}
/// A part of a [CircleEvaluation], for a specific coset that is a subset of the circle domain.
pub struct CosetSubEvaluation<'a, F: ExtensionOf<BaseField>> {
evaluation: &'a [F],
offset: usize,
step: isize,
}
impl<'a, F: ExtensionOf<BaseField>> CosetSubEvaluation<'a, F> {
fn new(evaluation: &'a [F], offset: usize, step: isize) -> Self {
assert!(evaluation.len().is_power_of_two());
Self {
evaluation,
offset,
step,
}
}
}
impl<'a, F: ExtensionOf<BaseField>> Index<isize> for CosetSubEvaluation<'a, F> {
type Output = F;
fn index(&self, index: isize) -> &Self::Output {
let index =
((self.offset as isize) + index * self.step) & ((self.evaluation.len() - 1) as isize);
&self.evaluation[index as usize]
}
}
impl<'a, F: ExtensionOf<BaseField>> Index<usize> for CosetSubEvaluation<'a, F> {
type Output = F;
fn index(&self, index: usize) -> &Self::Output {
&self[index as isize]
}
}
#[cfg(test)]
mod tests {
use crate::core::backend::cpu::CpuCircleEvaluation;
use crate::core::circle::Coset;
use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::CanonicCoset;
use crate::core::poly::NaturalOrder;
use crate::m31;
#[test]
fn test_interpolate_non_canonic() {
let domain = CanonicCoset::new(3).circle_domain();
assert_eq!(domain.log_size(), 3);
let evaluation = CpuCircleEvaluation::<_, NaturalOrder>::new(
domain,
(0..8).map(BaseField::from_u32_unchecked).collect(),
)
.bit_reverse();
let poly = evaluation.interpolate();
for (i, point) in domain.iter().enumerate() {
assert_eq!(poly.eval_at_point(point.into_ef()), m31!(i as u32).into());
}
}
#[test]
fn test_interpolate_canonic() {
let coset = CanonicCoset::new(3);
let evaluation = CpuCircleEvaluation::new_canonical_ordered(
coset,
(0..8).map(BaseField::from_u32_unchecked).collect(),
);
let poly = evaluation.interpolate();
for (i, point) in Coset::odds(3).iter().enumerate() {
assert_eq!(poly.eval_at_point(point.into_ef()), m31!(i as u32).into());
}
}
#[test]
pub fn test_get_at_circle_evaluation() {
let domain = CanonicCoset::new(7).circle_domain();
let values = (0..domain.size()).map(|i| m31!(i as u32)).collect();
let circle_evaluation = CpuCircleEvaluation::<_, NaturalOrder>::new(domain, values);
let bit_reversed_circle_evaluation = circle_evaluation.clone().bit_reverse();
for index in domain.iter_indices() {
assert_eq!(
circle_evaluation.get_at(index),
bit_reversed_circle_evaluation.get_at(index)
);
}
}
#[test]
fn test_sub_evaluation() {
let domain = CanonicCoset::new(7).circle_domain();
let values = (0..domain.size()).map(|i| m31!(i as u32)).collect();
let circle_evaluation = CpuCircleEvaluation::new(domain, values);
let coset = Coset::new(domain.index_at(17), 3);
let sub_eval = circle_evaluation.fetch_eval_on_coset(coset);
for i in 0..coset.size() {
assert_eq!(sub_eval[i], circle_evaluation.get_at(coset.index_at(i)));
}
}
}

View File

@ -0,0 +1,56 @@
mod canonic;
mod domain;
mod evaluation;
mod ops;
mod poly;
mod secure_poly;
pub use canonic::CanonicCoset;
pub use domain::{CircleDomain, MAX_CIRCLE_DOMAIN_LOG_SIZE};
pub use evaluation::{CircleEvaluation, CosetSubEvaluation};
pub use ops::PolyOps;
pub use poly::CirclePoly;
pub use secure_poly::{SecureCirclePoly, SecureEvaluation};
#[cfg(test)]
mod tests {
use super::CanonicCoset;
use crate::core::backend::cpu::CpuCircleEvaluation;
use crate::core::fields::m31::BaseField;
use crate::core::utils::bit_reverse_index;
#[test]
fn test_interpolate_and_eval() {
let domain = CanonicCoset::new(3).circle_domain();
assert_eq!(domain.log_size(), 3);
let evaluation =
CpuCircleEvaluation::new(domain, (0..8).map(BaseField::from_u32_unchecked).collect());
let poly = evaluation.clone().interpolate();
let evaluation2 = poly.evaluate(domain);
assert_eq!(evaluation.values, evaluation2.values);
}
#[test]
fn is_canonic_valid_domain() {
let canonic_domain = CanonicCoset::new(4).circle_domain();
assert!(canonic_domain.is_canonic());
}
#[test]
pub fn test_bit_reverse_indices() {
let log_domain_size = 7;
let log_small_domain_size = 5;
let domain = CanonicCoset::new(log_domain_size);
let small_domain = CanonicCoset::new(log_small_domain_size);
let n_folds = log_domain_size - log_small_domain_size;
for i in 0..2usize.pow(log_domain_size) {
let point = domain.at(bit_reverse_index(i, log_domain_size));
let small_point = small_domain.at(bit_reverse_index(
i / 2usize.pow(n_folds),
log_small_domain_size,
));
assert_eq!(point.repeated_double(n_folds), small_point);
}
}
}

View File

@ -0,0 +1,48 @@
use super::{CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly};
use crate::core::backend::Col;
use crate::core::circle::{CirclePoint, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldOps;
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::BitReversedOrder;
/// Operations on BaseField polynomials.
pub trait PolyOps: FieldOps<BaseField> + Sized {
// TODO(spapini): Use a column instead of this type.
/// The type for precomputed twiddles.
type Twiddles;
/// Creates a [CircleEvaluation] from values ordered according to [CanonicCoset].
/// Used by the [`CircleEvaluation::new_canonical_ordered()`] function.
fn new_canonical_ordered(
coset: CanonicCoset,
values: Col<Self, BaseField>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder>;
/// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation.
/// Used by the [`CircleEvaluation::interpolate()`] function.
fn interpolate(
eval: CircleEvaluation<Self, BaseField, BitReversedOrder>,
itwiddles: &TwiddleTree<Self>,
) -> CirclePoly<Self>;
/// Evaluates the polynomial at a single point.
/// Used by the [`CirclePoly::eval_at_point()`] function.
fn eval_at_point(poly: &CirclePoly<Self>, point: CirclePoint<SecureField>) -> SecureField;
/// Extends the polynomial to a larger degree bound.
/// Used by the [`CirclePoly::extend()`] function.
fn extend(poly: &CirclePoly<Self>, log_size: u32) -> CirclePoly<Self>;
/// Evaluates the polynomial at all points in the domain.
/// Used by the [`CirclePoly::evaluate()`] function.
fn evaluate(
poly: &CirclePoly<Self>,
domain: CircleDomain,
twiddles: &TwiddleTree<Self>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder>;
/// Precomputes twiddles for a given coset.
fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self>;
}

View File

@ -0,0 +1,118 @@
use super::{CircleDomain, CircleEvaluation, PolyOps};
use crate::core::backend::{Col, Column};
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldOps;
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::BitReversedOrder;
/// A polynomial defined on a [CircleDomain].
#[derive(Clone, Debug)]
pub struct CirclePoly<B: FieldOps<BaseField>> {
/// Coefficients of the polynomial in the FFT basis.
/// Note: These are not the coefficients of the polynomial in the standard
/// monomial basis. The FFT basis is a tensor product of the twiddles:
/// y, x, pi(x), pi^2(x), ..., pi^{log_size-2}(x).
/// pi(x) := 2x^2 - 1.
pub coeffs: Col<B, BaseField>,
/// The number of coefficients stored as `log2(len(coeffs))`.
log_size: u32,
}
impl<B: PolyOps> CirclePoly<B> {
/// Creates a new circle polynomial.
///
/// Coefficients must be in the circle IFFT algorithm's basis stored in bit-reversed order.
///
/// # Panics
///
/// Panics if the number of coefficients isn't a power of two.
pub fn new(coeffs: Col<B, BaseField>) -> Self {
assert!(coeffs.len().is_power_of_two());
let log_size = coeffs.len().ilog2();
Self { log_size, coeffs }
}
pub fn log_size(&self) -> u32 {
self.log_size
}
/// Evaluates the polynomial at a single point.
pub fn eval_at_point(&self, point: CirclePoint<SecureField>) -> SecureField {
B::eval_at_point(self, point)
}
/// Extends the polynomial to a larger degree bound.
pub fn extend(&self, log_size: u32) -> Self {
B::extend(self, log_size)
}
/// Evaluates the polynomial at all points in the domain.
pub fn evaluate(
&self,
domain: CircleDomain,
) -> CircleEvaluation<B, BaseField, BitReversedOrder> {
B::evaluate(self, domain, &B::precompute_twiddles(domain.half_coset))
}
/// Evaluates the polynomial at all points in the domain, using precomputed twiddles.
pub fn evaluate_with_twiddles(
&self,
domain: CircleDomain,
twiddles: &TwiddleTree<B>,
) -> CircleEvaluation<B, BaseField, BitReversedOrder> {
B::evaluate(self, domain, twiddles)
}
}
#[cfg(test)]
impl crate::core::backend::cpu::CpuCirclePoly {
pub fn is_in_fft_space(&self, log_fft_size: u32) -> bool {
use num_traits::Zero;
let mut coeffs = self.coeffs.clone();
while coeffs.last() == Some(&BaseField::zero()) {
coeffs.pop();
}
// The highest degree monomial in a fft-space polynomial is x^{(n/2) - 1}y.
// And it is at offset (n-1). x^{(n/2)} is at offset `n`, and is not allowed.
let highest_degree_allowed_monomial_offset = 1 << log_fft_size;
coeffs.len() <= highest_degree_allowed_monomial_offset
}
/// Fri space is the space of polynomials of total degree n/2.
/// Highest degree monomials are x^{n/2} and x^{(n/2)-1}y.
pub fn is_in_fri_space(&self, log_fft_size: u32) -> bool {
use num_traits::Zero;
let mut coeffs = self.coeffs.clone();
while coeffs.last() == Some(&BaseField::zero()) {
coeffs.pop();
}
// x^{n/2} is at offset `n`, and is the last offset allowed to be non-zero.
let highest_degree_monomial_offset = (1 << log_fft_size) + 1;
coeffs.len() <= highest_degree_monomial_offset
}
}
#[cfg(test)]
mod tests {
use crate::core::backend::cpu::CpuCirclePoly;
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
#[test]
fn test_circle_poly_extend() {
let poly = CpuCirclePoly::new((0..16).map(BaseField::from_u32_unchecked).collect());
let extended = poly.clone().extend(8);
let random_point = CirclePoint::get_point(21903);
assert_eq!(
poly.eval_at_point(random_point),
extended.eval_at_point(random_point)
);
}
}

View File

@ -0,0 +1,118 @@
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use super::{CircleDomain, CircleEvaluation, CirclePoly, PolyOps};
use crate::core::backend::CpuBackend;
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE};
use crate::core::fields::FieldOps;
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::BitReversedOrder;
pub struct SecureCirclePoly<B: FieldOps<BaseField>>(pub [CirclePoly<B>; SECURE_EXTENSION_DEGREE]);
impl<B: PolyOps> SecureCirclePoly<B> {
pub fn eval_at_point(&self, point: CirclePoint<SecureField>) -> SecureField {
SecureField::from_partial_evals(self.eval_columns_at_point(point))
}
pub fn eval_columns_at_point(
&self,
point: CirclePoint<SecureField>,
) -> [SecureField; SECURE_EXTENSION_DEGREE] {
[
self[0].eval_at_point(point),
self[1].eval_at_point(point),
self[2].eval_at_point(point),
self[3].eval_at_point(point),
]
}
pub fn log_size(&self) -> u32 {
self[0].log_size()
}
pub fn evaluate_with_twiddles(
&self,
domain: CircleDomain,
twiddles: &TwiddleTree<B>,
) -> SecureEvaluation<B, BitReversedOrder> {
let polys = self.0.each_ref();
let columns = polys.map(|poly| poly.evaluate_with_twiddles(domain, twiddles).values);
SecureEvaluation::new(domain, SecureColumnByCoords { columns })
}
}
impl<B: FieldOps<BaseField>> Deref for SecureCirclePoly<B> {
type Target = [CirclePoly<B>; SECURE_EXTENSION_DEGREE];
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// A [`SecureField`] evaluation defined on a [CircleDomain].
///
/// The evaluation is stored as a column major array of [`SECURE_EXTENSION_DEGREE`] many base field
/// evaluations. The evaluations are ordered according to the [CircleDomain] ordering.
#[derive(Clone)]
pub struct SecureEvaluation<B: FieldOps<BaseField>, EvalOrder> {
pub domain: CircleDomain,
pub values: SecureColumnByCoords<B>,
_eval_order: PhantomData<EvalOrder>,
}
impl<B: FieldOps<BaseField>, EvalOrder> SecureEvaluation<B, EvalOrder> {
pub fn new(domain: CircleDomain, values: SecureColumnByCoords<B>) -> Self {
assert_eq!(domain.size(), values.len());
Self {
domain,
values,
_eval_order: PhantomData,
}
}
pub fn into_coordinate_evals(
self,
) -> [CircleEvaluation<B, BaseField, EvalOrder>; SECURE_EXTENSION_DEGREE] {
let Self { domain, values, .. } = self;
values.columns.map(|c| CircleEvaluation::new(domain, c))
}
}
impl<B: FieldOps<BaseField>, EvalOrder> Deref for SecureEvaluation<B, EvalOrder> {
type Target = SecureColumnByCoords<B>;
fn deref(&self) -> &Self::Target {
&self.values
}
}
impl<B: FieldOps<BaseField>, EvalOrder> DerefMut for SecureEvaluation<B, EvalOrder> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.values
}
}
impl<B: PolyOps> SecureEvaluation<B, BitReversedOrder> {
/// Computes a minimal [`SecureCirclePoly`] that evaluates to the same values as this
/// evaluation, using precomputed twiddles.
pub fn interpolate_with_twiddles(self, twiddles: &TwiddleTree<B>) -> SecureCirclePoly<B> {
let domain = self.domain;
let cols = self.values.columns;
SecureCirclePoly(cols.map(|c| {
CircleEvaluation::<B, BaseField, BitReversedOrder>::new(domain, c)
.interpolate_with_twiddles(twiddles)
}))
}
}
impl<EvalOrder> From<CircleEvaluation<CpuBackend, SecureField, EvalOrder>>
for SecureEvaluation<CpuBackend, EvalOrder>
{
fn from(evaluation: CircleEvaluation<CpuBackend, SecureField, EvalOrder>) -> Self {
Self::new(evaluation.domain, evaluation.values.into_iter().collect())
}
}

Some files were not shown because too many files have changed in this diff Show More