diff --git a/Cargo.toml b/Cargo.toml index 357a6278..34b7e449 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["ecdsa", "evm", "field", "insertion", "maybe_rayon", "plonky2", "starky", "u32", "util", "waksman"] +members = ["evm", "field", "maybe_rayon", "plonky2", "starky", "util"] [profile.release] opt-level = 3 diff --git a/README.md b/README.md index a72edd5f..ce800dbd 100644 --- a/README.md +++ b/README.md @@ -59,3 +59,7 @@ Plonky2's default hash function is Poseidon, configured with 8 full rounds, 22 p ## Links - [System Zero](https://github.com/mir-protocol/system-zero), a zkVM built on top of Starky (no longer maintained) +- [Waksman](https://github.com/mir-protocol/plonky2-waksman), Plonky2 gadgets for permutation checking using Waksman networks (no longer maintained) +- [Insertion](https://github.com/mir-protocol/plonky2-insertion), Plonky2 gadgets for insertion into a list (no longer maintained) +- [u32](https://github.com/mir-protocol/plonky2-u32), Plonky2 gadgets for u32 arithmetic (no longer actively maintained) +- [ECDSA](https://github.com/mir-protocol/plonky2-ecdsa), Plonky2 gadgets for the ECDSA algorithm (no longer actively maintained) diff --git a/ecdsa/Cargo.toml b/ecdsa/Cargo.toml deleted file mode 100644 index 0a156654..00000000 --- a/ecdsa/Cargo.toml +++ /dev/null @@ -1,21 +0,0 @@ -[package] -name = "plonky2_ecdsa" -description = "ECDSA gadget for Plonky2" -version = "0.1.0" -license = "MIT OR Apache-2.0" -edition = "2021" - -[features] -parallel = ["plonky2_maybe_rayon/parallel", "plonky2/parallel"] - -[dependencies] -anyhow = { version = "1.0.40", default-features = false } -itertools = { version = "0.10.0", default-features = false } -plonky2_maybe_rayon = { version = "0.1.0", default-features = false } -num = { version = "0.4.0", default-features = false } -plonky2 = { version = "0.1.2", default-features = false } -plonky2_u32 = { version = "0.1.0", default-features = false } -serde = { version = "1.0", default-features = false, features = ["derive"] } - -[dev-dependencies] -rand = { version = "0.8.4", default-features = false, features = ["getrandom"] } diff --git a/ecdsa/LICENSE-APACHE b/ecdsa/LICENSE-APACHE deleted file mode 100644 index 1e5006dc..00000000 --- a/ecdsa/LICENSE-APACHE +++ /dev/null @@ -1,202 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - diff --git a/ecdsa/LICENSE-MIT b/ecdsa/LICENSE-MIT deleted file mode 100644 index 86d690b2..00000000 --- a/ecdsa/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2022 The Plonky2 Authors - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. diff --git a/ecdsa/README.md b/ecdsa/README.md deleted file mode 100644 index bb4e2d8a..00000000 --- a/ecdsa/README.md +++ /dev/null @@ -1,13 +0,0 @@ -## License - -Licensed under either of - -* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) -* MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) - -at your option. - - -### Contribution - -Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions. diff --git a/ecdsa/src/curve/curve_adds.rs b/ecdsa/src/curve/curve_adds.rs deleted file mode 100644 index 319c5614..00000000 --- a/ecdsa/src/curve/curve_adds.rs +++ /dev/null @@ -1,158 +0,0 @@ -use core::ops::Add; - -use plonky2::field::ops::Square; -use plonky2::field::types::Field; - -use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; - -impl Add> for ProjectivePoint { - type Output = ProjectivePoint; - - fn add(self, rhs: ProjectivePoint) -> Self::Output { - let ProjectivePoint { - x: x1, - y: y1, - z: z1, - } = self; - let ProjectivePoint { - x: x2, - y: y2, - z: z2, - } = rhs; - - if z1 == C::BaseField::ZERO { - return rhs; - } - if z2 == C::BaseField::ZERO { - return self; - } - - let x1z2 = x1 * z2; - let y1z2 = y1 * z2; - let x2z1 = x2 * z1; - let y2z1 = y2 * z1; - - // Check if we're doubling or adding inverses. - if x1z2 == x2z1 { - if y1z2 == y2z1 { - // TODO: inline to avoid redundant muls. - return self.double(); - } - if y1z2 == -y2z1 { - return ProjectivePoint::ZERO; - } - } - - // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/add-1998-cmo-2 - let z1z2 = z1 * z2; - let u = y2z1 - y1z2; - let uu = u.square(); - let v = x2z1 - x1z2; - let vv = v.square(); - let vvv = v * vv; - let r = vv * x1z2; - let a = uu * z1z2 - vvv - r.double(); - let x3 = v * a; - let y3 = u * (r - a) - vvv * y1z2; - let z3 = vvv * z1z2; - ProjectivePoint::nonzero(x3, y3, z3) - } -} - -impl Add> for ProjectivePoint { - type Output = ProjectivePoint; - - fn add(self, rhs: AffinePoint) -> Self::Output { - let ProjectivePoint { - x: x1, - y: y1, - z: z1, - } = self; - let AffinePoint { - x: x2, - y: y2, - zero: zero2, - } = rhs; - - if z1 == C::BaseField::ZERO { - return rhs.to_projective(); - } - if zero2 { - return self; - } - - let x2z1 = x2 * z1; - let y2z1 = y2 * z1; - - // Check if we're doubling or adding inverses. - if x1 == x2z1 { - if y1 == y2z1 { - // TODO: inline to avoid redundant muls. - return self.double(); - } - if y1 == -y2z1 { - return ProjectivePoint::ZERO; - } - } - - // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/madd-1998-cmo - let u = y2z1 - y1; - let uu = u.square(); - let v = x2z1 - x1; - let vv = v.square(); - let vvv = v * vv; - let r = vv * x1; - let a = uu * z1 - vvv - r.double(); - let x3 = v * a; - let y3 = u * (r - a) - vvv * y1; - let z3 = vvv * z1; - ProjectivePoint::nonzero(x3, y3, z3) - } -} - -impl Add> for AffinePoint { - type Output = ProjectivePoint; - - fn add(self, rhs: AffinePoint) -> Self::Output { - let AffinePoint { - x: x1, - y: y1, - zero: zero1, - } = self; - let AffinePoint { - x: x2, - y: y2, - zero: zero2, - } = rhs; - - if zero1 { - return rhs.to_projective(); - } - if zero2 { - return self.to_projective(); - } - - // Check if we're doubling or adding inverses. - if x1 == x2 { - if y1 == y2 { - return self.to_projective().double(); - } - if y1 == -y2 { - return ProjectivePoint::ZERO; - } - } - - // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/mmadd-1998-cmo - let u = y2 - y1; - let uu = u.square(); - let v = x2 - x1; - let vv = v.square(); - let vvv = v * vv; - let r = vv * x1; - let a = uu - vvv - r.double(); - let x3 = v * a; - let y3 = u * (r - a) - vvv * y1; - let z3 = vvv; - ProjectivePoint::nonzero(x3, y3, z3) - } -} diff --git a/ecdsa/src/curve/curve_msm.rs b/ecdsa/src/curve/curve_msm.rs deleted file mode 100644 index 9faa4a79..00000000 --- a/ecdsa/src/curve/curve_msm.rs +++ /dev/null @@ -1,265 +0,0 @@ -use alloc::vec::Vec; - -use itertools::Itertools; -use plonky2::field::types::{Field, PrimeField}; -use plonky2_maybe_rayon::*; - -use crate::curve::curve_summation::affine_multisummation_best; -use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; - -/// In Yao's method, we compute an affine summation for each digit. In a parallel setting, it would -/// be easiest to assign individual summations to threads, but this would be sub-optimal because -/// multi-summations can be more efficient than repeating individual summations (see -/// `affine_multisummation_best`). Thus we divide digits into large chunks, and assign chunks of -/// digits to threads. Note that there is a delicate balance here, as large chunks can result in -/// uneven distributions of work among threads. -const DIGITS_PER_CHUNK: usize = 80; - -#[derive(Clone, Debug)] -pub struct MsmPrecomputation { - /// For each generator (in the order they were passed to `msm_precompute`), contains a vector - /// of powers, i.e. [(2^w)^i] for i < DIGITS. - // TODO: Use compressed coordinates here. - powers_per_generator: Vec>>, - - /// The window size. - w: usize, -} - -pub fn msm_precompute( - generators: &[ProjectivePoint], - w: usize, -) -> MsmPrecomputation { - MsmPrecomputation { - powers_per_generator: generators - .into_par_iter() - .map(|&g| precompute_single_generator(g, w)) - .collect(), - w, - } -} - -fn precompute_single_generator(g: ProjectivePoint, w: usize) -> Vec> { - let digits = (C::ScalarField::BITS + w - 1) / w; - let mut powers: Vec> = Vec::with_capacity(digits); - powers.push(g); - for i in 1..digits { - let mut power_i_proj = powers[i - 1]; - for _j in 0..w { - power_i_proj = power_i_proj.double(); - } - powers.push(power_i_proj); - } - ProjectivePoint::batch_to_affine(&powers) -} - -pub fn msm_parallel( - scalars: &[C::ScalarField], - generators: &[ProjectivePoint], - w: usize, -) -> ProjectivePoint { - let precomputation = msm_precompute(generators, w); - msm_execute_parallel(&precomputation, scalars) -} - -pub fn msm_execute( - precomputation: &MsmPrecomputation, - scalars: &[C::ScalarField], -) -> ProjectivePoint { - assert_eq!(precomputation.powers_per_generator.len(), scalars.len()); - let w = precomputation.w; - let digits = (C::ScalarField::BITS + w - 1) / w; - let base = 1 << w; - - // This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use - // extremely large windows, the repeated scans in Yao's method could be more expensive than the - // actual group operations. To avoid this, we store a multimap from each possible digit to the - // positions in which that digit occurs in the scalars. These positions have the form (i, j), - // where i is the index of the generator and j is an index into the digits of the scalar - // associated with that generator. - let mut digit_occurrences: Vec> = Vec::with_capacity(digits); - for _i in 0..base { - digit_occurrences.push(Vec::new()); - } - for (i, scalar) in scalars.iter().enumerate() { - let digits = to_digits::(scalar, w); - for (j, &digit) in digits.iter().enumerate() { - digit_occurrences[digit].push((i, j)); - } - } - - let mut y = ProjectivePoint::ZERO; - let mut u = ProjectivePoint::ZERO; - - for digit in (1..base).rev() { - for &(i, j) in &digit_occurrences[digit] { - u = u + precomputation.powers_per_generator[i][j]; - } - y = y + u; - } - - y -} - -pub fn msm_execute_parallel( - precomputation: &MsmPrecomputation, - scalars: &[C::ScalarField], -) -> ProjectivePoint { - assert_eq!(precomputation.powers_per_generator.len(), scalars.len()); - let w = precomputation.w; - let digits = (C::ScalarField::BITS + w - 1) / w; - let base = 1 << w; - - // This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use - // extremely large windows, the repeated scans in Yao's method could be more expensive than the - // actual group operations. To avoid this, we store a multimap from each possible digit to the - // positions in which that digit occurs in the scalars. These positions have the form (i, j), - // where i is the index of the generator and j is an index into the digits of the scalar - // associated with that generator. - let mut digit_occurrences: Vec> = Vec::with_capacity(digits); - for _i in 0..base { - digit_occurrences.push(Vec::new()); - } - for (i, scalar) in scalars.iter().enumerate() { - let digits = to_digits::(scalar, w); - for (j, &digit) in digits.iter().enumerate() { - digit_occurrences[digit].push((i, j)); - } - } - - // For each digit, we add up the powers associated with all occurrences that digit. - let digits: Vec = (0..base).collect(); - let digit_acc: Vec> = digits - .par_chunks(DIGITS_PER_CHUNK) - .flat_map(|chunk| { - let summations: Vec>> = chunk - .iter() - .map(|&digit| { - digit_occurrences[digit] - .iter() - .map(|&(i, j)| precomputation.powers_per_generator[i][j]) - .collect() - }) - .collect(); - affine_multisummation_best(summations) - }) - .collect(); - // println!("Computing the per-digit summations (in parallel) took {}s", start.elapsed().as_secs_f64()); - - let mut y = ProjectivePoint::ZERO; - let mut u = ProjectivePoint::ZERO; - for digit in (1..base).rev() { - u = u + digit_acc[digit]; - y = y + u; - } - // println!("Final summation (sequential) {}s", start.elapsed().as_secs_f64()); - y -} - -pub(crate) fn to_digits(x: &C::ScalarField, w: usize) -> Vec { - let scalar_bits = C::ScalarField::BITS; - let num_digits = (scalar_bits + w - 1) / w; - - // Convert x to a bool array. - let x_canonical: Vec<_> = x - .to_canonical_biguint() - .to_u64_digits() - .iter() - .cloned() - .pad_using(scalar_bits / 64, |_| 0) - .collect(); - let mut x_bits = Vec::with_capacity(scalar_bits); - for i in 0..scalar_bits { - x_bits.push((x_canonical[i / 64] >> (i as u64 % 64) & 1) != 0); - } - - let mut digits = Vec::with_capacity(num_digits); - for i in 0..num_digits { - let mut digit = 0; - for j in ((i * w)..((i + 1) * w).min(scalar_bits)).rev() { - digit <<= 1; - digit |= x_bits[j] as usize; - } - digits.push(digit); - } - digits -} - -#[cfg(test)] -mod tests { - use alloc::vec; - - use num::BigUint; - use plonky2::field::secp256k1_scalar::Secp256K1Scalar; - - use super::*; - use crate::curve::secp256k1::Secp256K1; - - #[test] - fn test_to_digits() { - let x_canonical = [ - 0b10101010101010101010101010101010, - 0b10101010101010101010101010101010, - 0b11001100110011001100110011001100, - 0b11001100110011001100110011001100, - 0b11110000111100001111000011110000, - 0b11110000111100001111000011110000, - 0b00001111111111111111111111111111, - 0b11111111111111111111111111111111, - ]; - let x = Secp256K1Scalar::from_noncanonical_biguint(BigUint::from_slice(&x_canonical)); - assert_eq!(x.to_canonical_biguint().to_u32_digits(), x_canonical); - assert_eq!( - to_digits::(&x, 17), - vec![ - 0b01010101010101010, - 0b10101010101010101, - 0b01010101010101010, - 0b11001010101010101, - 0b01100110011001100, - 0b00110011001100110, - 0b10011001100110011, - 0b11110000110011001, - 0b01111000011110000, - 0b00111100001111000, - 0b00011110000111100, - 0b11111111111111110, - 0b01111111111111111, - 0b11111111111111000, - 0b11111111111111111, - 0b1, - ] - ); - } - - #[test] - fn test_msm() { - let w = 5; - - let generator_1 = Secp256K1::GENERATOR_PROJECTIVE; - let generator_2 = generator_1 + generator_1; - let generator_3 = generator_1 + generator_2; - - let scalar_1 = Secp256K1Scalar::from_noncanonical_biguint(BigUint::from_slice(&[ - 11111111, 22222222, 33333333, 44444444, - ])); - let scalar_2 = Secp256K1Scalar::from_noncanonical_biguint(BigUint::from_slice(&[ - 22222222, 22222222, 33333333, 44444444, - ])); - let scalar_3 = Secp256K1Scalar::from_noncanonical_biguint(BigUint::from_slice(&[ - 33333333, 22222222, 33333333, 44444444, - ])); - - let generators = vec![generator_1, generator_2, generator_3]; - let scalars = vec![scalar_1, scalar_2, scalar_3]; - - let precomputation = msm_precompute(&generators, w); - let result_msm = msm_execute(&precomputation, &scalars); - - let result_naive = Secp256K1::convert(scalar_1) * generator_1 - + Secp256K1::convert(scalar_2) * generator_2 - + Secp256K1::convert(scalar_3) * generator_3; - - assert_eq!(result_msm, result_naive); - } -} diff --git a/ecdsa/src/curve/curve_multiplication.rs b/ecdsa/src/curve/curve_multiplication.rs deleted file mode 100644 index 1f9c653d..00000000 --- a/ecdsa/src/curve/curve_multiplication.rs +++ /dev/null @@ -1,100 +0,0 @@ -use alloc::vec::Vec; -use core::ops::Mul; - -use plonky2::field::types::{Field, PrimeField}; - -use crate::curve::curve_types::{Curve, CurveScalar, ProjectivePoint}; - -const WINDOW_BITS: usize = 4; -const BASE: usize = 1 << WINDOW_BITS; - -fn digits_per_scalar() -> usize { - (C::ScalarField::BITS + WINDOW_BITS - 1) / WINDOW_BITS -} - -/// Precomputed state used for scalar x ProjectivePoint multiplications, -/// specific to a particular generator. -#[derive(Clone)] -pub struct MultiplicationPrecomputation { - /// [(2^w)^i] g for each i < digits_per_scalar. - powers: Vec>, -} - -impl ProjectivePoint { - pub fn mul_precompute(&self) -> MultiplicationPrecomputation { - let num_digits = digits_per_scalar::(); - let mut powers = Vec::with_capacity(num_digits); - powers.push(*self); - for i in 1..num_digits { - let mut power_i = powers[i - 1]; - for _j in 0..WINDOW_BITS { - power_i = power_i.double(); - } - powers.push(power_i); - } - - MultiplicationPrecomputation { powers } - } - - #[must_use] - pub fn mul_with_precomputation( - &self, - scalar: C::ScalarField, - precomputation: MultiplicationPrecomputation, - ) -> Self { - // Yao's method; see https://koclab.cs.ucsb.edu/teaching/ecc/eccPapers/Doche-ch09.pdf - let precomputed_powers = precomputation.powers; - - let digits = to_digits::(&scalar); - - let mut y = ProjectivePoint::ZERO; - let mut u = ProjectivePoint::ZERO; - let mut all_summands = Vec::new(); - for j in (1..BASE).rev() { - let mut u_summands = Vec::new(); - for (i, &digit) in digits.iter().enumerate() { - if digit == j as u64 { - u_summands.push(precomputed_powers[i]); - } - } - all_summands.push(u_summands); - } - - let all_sums: Vec> = all_summands - .iter() - .cloned() - .map(|vec| vec.iter().fold(ProjectivePoint::ZERO, |a, &b| a + b)) - .collect(); - for i in 0..all_sums.len() { - u = u + all_sums[i]; - y = y + u; - } - y - } -} - -impl Mul> for CurveScalar { - type Output = ProjectivePoint; - - fn mul(self, rhs: ProjectivePoint) -> Self::Output { - let precomputation = rhs.mul_precompute(); - rhs.mul_with_precomputation(self.0, precomputation) - } -} - -#[allow(clippy::assertions_on_constants)] -fn to_digits(x: &C::ScalarField) -> Vec { - debug_assert!( - 64 % WINDOW_BITS == 0, - "For simplicity, only power-of-two window sizes are handled for now" - ); - let digits_per_u64 = 64 / WINDOW_BITS; - let mut digits = Vec::with_capacity(digits_per_scalar::()); - for limb in x.to_canonical_biguint().to_u64_digits() { - for j in 0..digits_per_u64 { - digits.push((limb >> (j * WINDOW_BITS) as u64) % BASE as u64); - } - } - - digits -} diff --git a/ecdsa/src/curve/curve_summation.rs b/ecdsa/src/curve/curve_summation.rs deleted file mode 100644 index 7bb633af..00000000 --- a/ecdsa/src/curve/curve_summation.rs +++ /dev/null @@ -1,238 +0,0 @@ -use alloc::vec; -use alloc::vec::Vec; -use core::iter::Sum; - -use plonky2::field::ops::Square; -use plonky2::field::types::Field; - -use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; - -impl Sum> for ProjectivePoint { - fn sum>>(iter: I) -> ProjectivePoint { - let points: Vec<_> = iter.collect(); - affine_summation_best(points) - } -} - -impl Sum for ProjectivePoint { - fn sum>>(iter: I) -> ProjectivePoint { - iter.fold(ProjectivePoint::ZERO, |acc, x| acc + x) - } -} - -pub fn affine_summation_best(summation: Vec>) -> ProjectivePoint { - let result = affine_multisummation_best(vec![summation]); - debug_assert_eq!(result.len(), 1); - result[0] -} - -pub fn affine_multisummation_best( - summations: Vec>>, -) -> Vec> { - let pairwise_sums: usize = summations.iter().map(|summation| summation.len() / 2).sum(); - - // This threshold is chosen based on data from the summation benchmarks. - if pairwise_sums < 70 { - affine_multisummation_pairwise(summations) - } else { - affine_multisummation_batch_inversion(summations) - } -} - -/// Adds each pair of points using an affine + affine = projective formula, then adds up the -/// intermediate sums using a projective formula. -pub fn affine_multisummation_pairwise( - summations: Vec>>, -) -> Vec> { - summations - .into_iter() - .map(affine_summation_pairwise) - .collect() -} - -/// Adds each pair of points using an affine + affine = projective formula, then adds up the -/// intermediate sums using a projective formula. -pub fn affine_summation_pairwise(points: Vec>) -> ProjectivePoint { - let mut reduced_points: Vec> = Vec::new(); - for chunk in points.chunks(2) { - match chunk.len() { - 1 => reduced_points.push(chunk[0].to_projective()), - 2 => reduced_points.push(chunk[0] + chunk[1]), - _ => panic!(), - } - } - // TODO: Avoid copying (deref) - reduced_points - .iter() - .fold(ProjectivePoint::ZERO, |sum, x| sum + *x) -} - -/// Computes several summations of affine points by applying an affine group law, except that the -/// divisions are batched via Montgomery's trick. -pub fn affine_summation_batch_inversion( - summation: Vec>, -) -> ProjectivePoint { - let result = affine_multisummation_batch_inversion(vec![summation]); - debug_assert_eq!(result.len(), 1); - result[0] -} - -/// Computes several summations of affine points by applying an affine group law, except that the -/// divisions are batched via Montgomery's trick. -pub fn affine_multisummation_batch_inversion( - summations: Vec>>, -) -> Vec> { - let mut elements_to_invert = Vec::new(); - - // For each pair of points, (x1, y1) and (x2, y2), that we're going to add later, we want to - // invert either y (if the points are equal) or x1 - x2 (otherwise). We will use these later. - for summation in &summations { - let n = summation.len(); - // The special case for n=0 is to avoid underflow. - let range_end = if n == 0 { 0 } else { n - 1 }; - - for i in (0..range_end).step_by(2) { - let p1 = summation[i]; - let p2 = summation[i + 1]; - let AffinePoint { - x: x1, - y: y1, - zero: zero1, - } = p1; - let AffinePoint { - x: x2, - y: _y2, - zero: zero2, - } = p2; - - if zero1 || zero2 || p1 == -p2 { - // These are trivial cases where we won't need any inverse. - } else if p1 == p2 { - elements_to_invert.push(y1.double()); - } else { - elements_to_invert.push(x1 - x2); - } - } - } - - let inverses: Vec = - C::BaseField::batch_multiplicative_inverse(&elements_to_invert); - - let mut all_reduced_points = Vec::with_capacity(summations.len()); - let mut inverse_index = 0; - for summation in summations { - let n = summation.len(); - let mut reduced_points = Vec::with_capacity((n + 1) / 2); - - // The special case for n=0 is to avoid underflow. - let range_end = if n == 0 { 0 } else { n - 1 }; - - for i in (0..range_end).step_by(2) { - let p1 = summation[i]; - let p2 = summation[i + 1]; - let AffinePoint { - x: x1, - y: y1, - zero: zero1, - } = p1; - let AffinePoint { - x: x2, - y: y2, - zero: zero2, - } = p2; - - let sum = if zero1 { - p2 - } else if zero2 { - p1 - } else if p1 == -p2 { - AffinePoint::ZERO - } else { - // It's a non-trivial case where we need one of the inverses we computed earlier. - let inverse = inverses[inverse_index]; - inverse_index += 1; - - if p1 == p2 { - // This is the doubling case. - let mut numerator = x1.square().triple(); - if C::A.is_nonzero() { - numerator += C::A; - } - let quotient = numerator * inverse; - let x3 = quotient.square() - x1.double(); - let y3 = quotient * (x1 - x3) - y1; - AffinePoint::nonzero(x3, y3) - } else { - // This is the general case. We use the incomplete addition formulas 4.3 and 4.4. - let quotient = (y1 - y2) * inverse; - let x3 = quotient.square() - x1 - x2; - let y3 = quotient * (x1 - x3) - y1; - AffinePoint::nonzero(x3, y3) - } - }; - reduced_points.push(sum); - } - - // If n is odd, the last point was not part of a pair. - if n % 2 == 1 { - reduced_points.push(summation[n - 1]); - } - - all_reduced_points.push(reduced_points); - } - - // We should have consumed all of the inverses from the batch computation. - debug_assert_eq!(inverse_index, inverses.len()); - - // Recurse with our smaller set of points. - affine_multisummation_best(all_reduced_points) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::curve::secp256k1::Secp256K1; - - #[test] - fn test_pairwise_affine_summation() { - let g_affine = Secp256K1::GENERATOR_AFFINE; - let g2_affine = (g_affine + g_affine).to_affine(); - let g3_affine = (g_affine + g_affine + g_affine).to_affine(); - let g2_proj = g2_affine.to_projective(); - let g3_proj = g3_affine.to_projective(); - assert_eq!( - affine_summation_pairwise::(vec![g_affine, g_affine]), - g2_proj - ); - assert_eq!( - affine_summation_pairwise::(vec![g_affine, g2_affine]), - g3_proj - ); - assert_eq!( - affine_summation_pairwise::(vec![g_affine, g_affine, g_affine]), - g3_proj - ); - assert_eq!( - affine_summation_pairwise::(vec![]), - ProjectivePoint::ZERO - ); - } - - #[test] - fn test_pairwise_affine_summation_batch_inversion() { - let g = Secp256K1::GENERATOR_AFFINE; - let g_proj = g.to_projective(); - assert_eq!( - affine_summation_batch_inversion::(vec![g, g]), - g_proj + g_proj - ); - assert_eq!( - affine_summation_batch_inversion::(vec![g, g, g]), - g_proj + g_proj + g_proj - ); - assert_eq!( - affine_summation_batch_inversion::(vec![]), - ProjectivePoint::ZERO - ); - } -} diff --git a/ecdsa/src/curve/curve_types.rs b/ecdsa/src/curve/curve_types.rs deleted file mode 100644 index 91047393..00000000 --- a/ecdsa/src/curve/curve_types.rs +++ /dev/null @@ -1,286 +0,0 @@ -use alloc::vec::Vec; -use core::fmt::Debug; -use core::hash::{Hash, Hasher}; -use core::ops::Neg; - -use plonky2::field::ops::Square; -use plonky2::field::types::{Field, PrimeField}; -use serde::{Deserialize, Serialize}; - -// To avoid implementation conflicts from associated types, -// see https://github.com/rust-lang/rust/issues/20400 -pub struct CurveScalar(pub ::ScalarField); - -/// A short Weierstrass curve. -pub trait Curve: 'static + Sync + Sized + Copy + Debug { - type BaseField: PrimeField; - type ScalarField: PrimeField; - - const A: Self::BaseField; - const B: Self::BaseField; - - const GENERATOR_AFFINE: AffinePoint; - - const GENERATOR_PROJECTIVE: ProjectivePoint = ProjectivePoint { - x: Self::GENERATOR_AFFINE.x, - y: Self::GENERATOR_AFFINE.y, - z: Self::BaseField::ONE, - }; - - fn convert(x: Self::ScalarField) -> CurveScalar { - CurveScalar(x) - } - - fn is_safe_curve() -> bool { - // Added additional check to prevent using vulnerabilties in case a discriminant is equal to 0. - (Self::A.cube().double().double() + Self::B.square().triple().triple().triple()) - .is_nonzero() - } -} - -/// A point on a short Weierstrass curve, represented in affine coordinates. -#[derive(Copy, Clone, Debug, Deserialize, Serialize)] -pub struct AffinePoint { - pub x: C::BaseField, - pub y: C::BaseField, - pub zero: bool, -} - -impl AffinePoint { - pub const ZERO: Self = Self { - x: C::BaseField::ZERO, - y: C::BaseField::ZERO, - zero: true, - }; - - pub fn nonzero(x: C::BaseField, y: C::BaseField) -> Self { - let point = Self { x, y, zero: false }; - debug_assert!(point.is_valid()); - point - } - - pub fn is_valid(&self) -> bool { - let Self { x, y, zero } = *self; - zero || y.square() == x.cube() + C::A * x + C::B - } - - pub fn to_projective(&self) -> ProjectivePoint { - let Self { x, y, zero } = *self; - let z = if zero { - C::BaseField::ZERO - } else { - C::BaseField::ONE - }; - - ProjectivePoint { x, y, z } - } - - pub fn batch_to_projective(affine_points: &[Self]) -> Vec> { - affine_points.iter().map(Self::to_projective).collect() - } - - #[must_use] - pub fn double(&self) -> Self { - let AffinePoint { x: x1, y: y1, zero } = *self; - - if zero { - return AffinePoint::ZERO; - } - - let double_y = y1.double(); - let inv_double_y = double_y.inverse(); // (2y)^(-1) - let triple_xx = x1.square().triple(); // 3x^2 - let lambda = (triple_xx + C::A) * inv_double_y; - let x3 = lambda.square() - self.x.double(); - let y3 = lambda * (x1 - x3) - y1; - - Self { - x: x3, - y: y3, - zero: false, - } - } -} - -impl PartialEq for AffinePoint { - fn eq(&self, other: &Self) -> bool { - let AffinePoint { - x: x1, - y: y1, - zero: zero1, - } = *self; - let AffinePoint { - x: x2, - y: y2, - zero: zero2, - } = *other; - if zero1 || zero2 { - return zero1 == zero2; - } - x1 == x2 && y1 == y2 - } -} - -impl Eq for AffinePoint {} - -impl Hash for AffinePoint { - fn hash(&self, state: &mut H) { - if self.zero { - self.zero.hash(state); - } else { - self.x.hash(state); - self.y.hash(state); - } - } -} - -/// A point on a short Weierstrass curve, represented in projective coordinates. -#[derive(Copy, Clone, Debug)] -pub struct ProjectivePoint { - pub x: C::BaseField, - pub y: C::BaseField, - pub z: C::BaseField, -} - -impl ProjectivePoint { - pub const ZERO: Self = Self { - x: C::BaseField::ZERO, - y: C::BaseField::ONE, - z: C::BaseField::ZERO, - }; - - pub fn nonzero(x: C::BaseField, y: C::BaseField, z: C::BaseField) -> Self { - let point = Self { x, y, z }; - debug_assert!(point.is_valid()); - point - } - - pub fn is_valid(&self) -> bool { - let Self { x, y, z } = *self; - z.is_zero() || y.square() * z == x.cube() + C::A * x * z.square() + C::B * z.cube() - } - - pub fn to_affine(&self) -> AffinePoint { - let Self { x, y, z } = *self; - if z == C::BaseField::ZERO { - AffinePoint::ZERO - } else { - let z_inv = z.inverse(); - AffinePoint::nonzero(x * z_inv, y * z_inv) - } - } - - pub fn batch_to_affine(proj_points: &[Self]) -> Vec> { - let n = proj_points.len(); - let zs: Vec = proj_points.iter().map(|pp| pp.z).collect(); - let z_invs = C::BaseField::batch_multiplicative_inverse(&zs); - - let mut result = Vec::with_capacity(n); - for i in 0..n { - let Self { x, y, z } = proj_points[i]; - result.push(if z == C::BaseField::ZERO { - AffinePoint::ZERO - } else { - let z_inv = z_invs[i]; - AffinePoint::nonzero(x * z_inv, y * z_inv) - }); - } - result - } - - // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/doubling/dbl-2007-bl - #[must_use] - pub fn double(&self) -> Self { - let Self { x, y, z } = *self; - if z == C::BaseField::ZERO { - return ProjectivePoint::ZERO; - } - - let xx = x.square(); - let zz = z.square(); - let mut w = xx.triple(); - if C::A.is_nonzero() { - w += C::A * zz; - } - let s = y.double() * z; - let r = y * s; - let rr = r.square(); - let b = (x + r).square() - (xx + rr); - let h = w.square() - b.double(); - let x3 = h * s; - let y3 = w * (b - h) - rr.double(); - let z3 = s.cube(); - Self { - x: x3, - y: y3, - z: z3, - } - } - - pub fn add_slices(a: &[Self], b: &[Self]) -> Vec { - assert_eq!(a.len(), b.len()); - a.iter() - .zip(b.iter()) - .map(|(&a_i, &b_i)| a_i + b_i) - .collect() - } - - #[must_use] - pub fn neg(&self) -> Self { - Self { - x: self.x, - y: -self.y, - z: self.z, - } - } -} - -impl PartialEq for ProjectivePoint { - fn eq(&self, other: &Self) -> bool { - let ProjectivePoint { - x: x1, - y: y1, - z: z1, - } = *self; - let ProjectivePoint { - x: x2, - y: y2, - z: z2, - } = *other; - if z1 == C::BaseField::ZERO || z2 == C::BaseField::ZERO { - return z1 == z2; - } - - // We want to compare (x1/z1, y1/z1) == (x2/z2, y2/z2). - // But to avoid field division, it is better to compare (x1*z2, y1*z2) == (x2*z1, y2*z1). - x1 * z2 == x2 * z1 && y1 * z2 == y2 * z1 - } -} - -impl Eq for ProjectivePoint {} - -impl Neg for AffinePoint { - type Output = AffinePoint; - - fn neg(self) -> Self::Output { - let AffinePoint { x, y, zero } = self; - AffinePoint { x, y: -y, zero } - } -} - -impl Neg for ProjectivePoint { - type Output = ProjectivePoint; - - fn neg(self) -> Self::Output { - let ProjectivePoint { x, y, z } = self; - ProjectivePoint { x, y: -y, z } - } -} - -pub fn base_to_scalar(x: C::BaseField) -> C::ScalarField { - C::ScalarField::from_noncanonical_biguint(x.to_canonical_biguint()) -} - -pub fn scalar_to_base(x: C::ScalarField) -> C::BaseField { - C::BaseField::from_noncanonical_biguint(x.to_canonical_biguint()) -} diff --git a/ecdsa/src/curve/ecdsa.rs b/ecdsa/src/curve/ecdsa.rs deleted file mode 100644 index 131d8b4d..00000000 --- a/ecdsa/src/curve/ecdsa.rs +++ /dev/null @@ -1,84 +0,0 @@ -use plonky2::field::types::{Field, Sample}; -use serde::{Deserialize, Serialize}; - -use crate::curve::curve_msm::msm_parallel; -use crate::curve::curve_types::{base_to_scalar, AffinePoint, Curve, CurveScalar}; - -#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] -pub struct ECDSASignature { - pub r: C::ScalarField, - pub s: C::ScalarField, -} - -#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] -pub struct ECDSASecretKey(pub C::ScalarField); - -impl ECDSASecretKey { - pub fn to_public(&self) -> ECDSAPublicKey { - ECDSAPublicKey((CurveScalar(self.0) * C::GENERATOR_PROJECTIVE).to_affine()) - } -} - -#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] -pub struct ECDSAPublicKey(pub AffinePoint); - -pub fn sign_message(msg: C::ScalarField, sk: ECDSASecretKey) -> ECDSASignature { - let (k, rr) = { - let mut k = C::ScalarField::rand(); - let mut rr = (CurveScalar(k) * C::GENERATOR_PROJECTIVE).to_affine(); - while rr.x == C::BaseField::ZERO { - k = C::ScalarField::rand(); - rr = (CurveScalar(k) * C::GENERATOR_PROJECTIVE).to_affine(); - } - (k, rr) - }; - let r = base_to_scalar::(rr.x); - - let s = k.inverse() * (msg + r * sk.0); - - ECDSASignature { r, s } -} - -pub fn verify_message( - msg: C::ScalarField, - sig: ECDSASignature, - pk: ECDSAPublicKey, -) -> bool { - let ECDSASignature { r, s } = sig; - - assert!(pk.0.is_valid()); - - let c = s.inverse(); - let u1 = msg * c; - let u2 = r * c; - - let g = C::GENERATOR_PROJECTIVE; - let w = 5; // Experimentally fastest - let point_proj = msm_parallel(&[u1, u2], &[g, pk.0.to_projective()], w); - let point = point_proj.to_affine(); - - let x = base_to_scalar::(point.x); - r == x -} - -#[cfg(test)] -mod tests { - use plonky2::field::secp256k1_scalar::Secp256K1Scalar; - use plonky2::field::types::Sample; - - use crate::curve::ecdsa::{sign_message, verify_message, ECDSASecretKey}; - use crate::curve::secp256k1::Secp256K1; - - #[test] - fn test_ecdsa_native() { - type C = Secp256K1; - - let msg = Secp256K1Scalar::rand(); - let sk = ECDSASecretKey::(Secp256K1Scalar::rand()); - let pk = sk.to_public(); - - let sig = sign_message(msg, sk); - let result = verify_message(msg, sig, pk); - assert!(result); - } -} diff --git a/ecdsa/src/curve/glv.rs b/ecdsa/src/curve/glv.rs deleted file mode 100644 index 7c3e5de0..00000000 --- a/ecdsa/src/curve/glv.rs +++ /dev/null @@ -1,140 +0,0 @@ -use num::rational::Ratio; -use num::BigUint; -use plonky2::field::secp256k1_base::Secp256K1Base; -use plonky2::field::secp256k1_scalar::Secp256K1Scalar; -use plonky2::field::types::{Field, PrimeField}; - -use crate::curve::curve_msm::msm_parallel; -use crate::curve::curve_types::{AffinePoint, ProjectivePoint}; -use crate::curve::secp256k1::Secp256K1; - -pub const GLV_BETA: Secp256K1Base = Secp256K1Base([ - 13923278643952681454, - 11308619431505398165, - 7954561588662645993, - 8856726876819556112, -]); - -pub const GLV_S: Secp256K1Scalar = Secp256K1Scalar([ - 16069571880186789234, - 1310022930574435960, - 11900229862571533402, - 6008836872998760672, -]); - -const A1: Secp256K1Scalar = Secp256K1Scalar([16747920425669159701, 3496713202691238861, 0, 0]); - -const MINUS_B1: Secp256K1Scalar = - Secp256K1Scalar([8022177200260244675, 16448129721693014056, 0, 0]); - -const A2: Secp256K1Scalar = Secp256K1Scalar([6323353552219852760, 1498098850674701302, 1, 0]); - -const B2: Secp256K1Scalar = Secp256K1Scalar([16747920425669159701, 3496713202691238861, 0, 0]); - -/// Algorithm 15.41 in Handbook of Elliptic and Hyperelliptic Curve Cryptography. -/// Decompose a scalar `k` into two small scalars `k1, k2` with `|k1|, |k2| < √p` that satisfy -/// `k1 + s * k2 = k`. -/// Returns `(|k1|, |k2|, k1 < 0, k2 < 0)`. -pub fn decompose_secp256k1_scalar( - k: Secp256K1Scalar, -) -> (Secp256K1Scalar, Secp256K1Scalar, bool, bool) { - let p = Secp256K1Scalar::order(); - let c1_biguint = Ratio::new( - B2.to_canonical_biguint() * k.to_canonical_biguint(), - p.clone(), - ) - .round() - .to_integer(); - let c1 = Secp256K1Scalar::from_noncanonical_biguint(c1_biguint); - let c2_biguint = Ratio::new( - MINUS_B1.to_canonical_biguint() * k.to_canonical_biguint(), - p.clone(), - ) - .round() - .to_integer(); - let c2 = Secp256K1Scalar::from_noncanonical_biguint(c2_biguint); - - let k1_raw = k - c1 * A1 - c2 * A2; - let k2_raw = c1 * MINUS_B1 - c2 * B2; - debug_assert!(k1_raw + GLV_S * k2_raw == k); - - let two = BigUint::from_slice(&[2]); - let k1_neg = k1_raw.to_canonical_biguint() > p.clone() / two.clone(); - let k1 = if k1_neg { - Secp256K1Scalar::from_noncanonical_biguint(p.clone() - k1_raw.to_canonical_biguint()) - } else { - k1_raw - }; - let k2_neg = k2_raw.to_canonical_biguint() > p.clone() / two; - let k2 = if k2_neg { - Secp256K1Scalar::from_noncanonical_biguint(p - k2_raw.to_canonical_biguint()) - } else { - k2_raw - }; - - (k1, k2, k1_neg, k2_neg) -} - -/// See Section 15.2.1 in Handbook of Elliptic and Hyperelliptic Curve Cryptography. -/// GLV scalar multiplication `k * P = k1 * P + k2 * psi(P)`, where `k = k1 + s * k2` is the -/// decomposition computed in `decompose_secp256k1_scalar(k)` and `psi` is the Secp256k1 -/// endomorphism `psi: (x, y) |-> (beta * x, y)` equivalent to scalar multiplication by `s`. -pub fn glv_mul(p: ProjectivePoint, k: Secp256K1Scalar) -> ProjectivePoint { - let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); - - let p_affine = p.to_affine(); - let sp = AffinePoint:: { - x: p_affine.x * GLV_BETA, - y: p_affine.y, - zero: p_affine.zero, - }; - - let first = if k1_neg { p.neg() } else { p }; - let second = if k2_neg { - sp.to_projective().neg() - } else { - sp.to_projective() - }; - - msm_parallel(&[k1, k2], &[first, second], 5) -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::field::secp256k1_scalar::Secp256K1Scalar; - use plonky2::field::types::{Field, Sample}; - - use crate::curve::curve_types::{Curve, CurveScalar}; - use crate::curve::glv::{decompose_secp256k1_scalar, glv_mul, GLV_S}; - use crate::curve::secp256k1::Secp256K1; - - #[test] - fn test_glv_decompose() -> Result<()> { - let k = Secp256K1Scalar::rand(); - let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); - let one = Secp256K1Scalar::ONE; - let m1 = if k1_neg { -one } else { one }; - let m2 = if k2_neg { -one } else { one }; - - assert!(k1 * m1 + GLV_S * k2 * m2 == k); - - Ok(()) - } - - #[test] - fn test_glv_mul() -> Result<()> { - for _ in 0..20 { - let k = Secp256K1Scalar::rand(); - - let p = CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE; - - let kp = CurveScalar(k) * p; - let glv = glv_mul(p, k); - - assert!(kp == glv); - } - - Ok(()) - } -} diff --git a/ecdsa/src/curve/mod.rs b/ecdsa/src/curve/mod.rs deleted file mode 100644 index 1984b0c6..00000000 --- a/ecdsa/src/curve/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -pub mod curve_adds; -pub mod curve_msm; -pub mod curve_multiplication; -pub mod curve_summation; -pub mod curve_types; -pub mod ecdsa; -pub mod glv; -pub mod secp256k1; diff --git a/ecdsa/src/curve/secp256k1.rs b/ecdsa/src/curve/secp256k1.rs deleted file mode 100644 index 0b899a71..00000000 --- a/ecdsa/src/curve/secp256k1.rs +++ /dev/null @@ -1,100 +0,0 @@ -use plonky2::field::secp256k1_base::Secp256K1Base; -use plonky2::field::secp256k1_scalar::Secp256K1Scalar; -use plonky2::field::types::Field; -use serde::{Deserialize, Serialize}; - -use crate::curve::curve_types::{AffinePoint, Curve}; - -#[derive(Debug, Copy, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)] -pub struct Secp256K1; - -impl Curve for Secp256K1 { - type BaseField = Secp256K1Base; - type ScalarField = Secp256K1Scalar; - - const A: Secp256K1Base = Secp256K1Base::ZERO; - const B: Secp256K1Base = Secp256K1Base([7, 0, 0, 0]); - const GENERATOR_AFFINE: AffinePoint = AffinePoint { - x: SECP256K1_GENERATOR_X, - y: SECP256K1_GENERATOR_Y, - zero: false, - }; -} - -// 55066263022277343669578718895168534326250603453777594175500187360389116729240 -const SECP256K1_GENERATOR_X: Secp256K1Base = Secp256K1Base([ - 0x59F2815B16F81798, - 0x029BFCDB2DCE28D9, - 0x55A06295CE870B07, - 0x79BE667EF9DCBBAC, -]); - -/// 32670510020758816978083085130507043184471273380659243275938904335757337482424 -const SECP256K1_GENERATOR_Y: Secp256K1Base = Secp256K1Base([ - 0x9C47D08FFB10D4B8, - 0xFD17B448A6855419, - 0x5DA4FBFC0E1108A8, - 0x483ADA7726A3C465, -]); - -#[cfg(test)] -mod tests { - use num::BigUint; - use plonky2::field::secp256k1_scalar::Secp256K1Scalar; - use plonky2::field::types::{Field, PrimeField}; - - use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; - use crate::curve::secp256k1::Secp256K1; - - #[test] - fn test_generator() { - let g = Secp256K1::GENERATOR_AFFINE; - assert!(g.is_valid()); - - let neg_g = AffinePoint:: { - x: g.x, - y: -g.y, - zero: g.zero, - }; - assert!(neg_g.is_valid()); - } - - #[test] - fn test_naive_multiplication() { - let g = Secp256K1::GENERATOR_PROJECTIVE; - let ten = Secp256K1Scalar::from_canonical_u64(10); - let product = mul_naive(ten, g); - let sum = g + g + g + g + g + g + g + g + g + g; - assert_eq!(product, sum); - } - - #[test] - fn test_g1_multiplication() { - let lhs = Secp256K1Scalar::from_noncanonical_biguint(BigUint::from_slice(&[ - 1111, 2222, 3333, 4444, 5555, 6666, 7777, 8888, - ])); - assert_eq!( - Secp256K1::convert(lhs) * Secp256K1::GENERATOR_PROJECTIVE, - mul_naive(lhs, Secp256K1::GENERATOR_PROJECTIVE) - ); - } - - /// A simple, somewhat inefficient implementation of multiplication which is used as a reference - /// for correctness. - fn mul_naive( - lhs: Secp256K1Scalar, - rhs: ProjectivePoint, - ) -> ProjectivePoint { - let mut g = rhs; - let mut sum = ProjectivePoint::ZERO; - for limb in lhs.to_canonical_biguint().to_u64_digits().iter() { - for j in 0..64 { - if (limb >> j & 1u64) != 0u64 { - sum = sum + g; - } - g = g.double(); - } - } - sum - } -} diff --git a/ecdsa/src/gadgets/biguint.rs b/ecdsa/src/gadgets/biguint.rs deleted file mode 100644 index 59e48d01..00000000 --- a/ecdsa/src/gadgets/biguint.rs +++ /dev/null @@ -1,508 +0,0 @@ -use alloc::vec; -use alloc::vec::Vec; -use core::marker::PhantomData; - -use num::{BigUint, Integer, Zero}; -use plonky2::field::extension::Extendable; -use plonky2::field::types::{PrimeField, PrimeField64}; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; -use plonky2::iop::target::{BoolTarget, Target}; -use plonky2::iop::witness::{PartitionWitness, Witness}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; -use plonky2_u32::gadgets::multiple_comparison::list_le_u32_circuit; -use plonky2_u32::witness::{GeneratedValuesU32, WitnessU32}; - -#[derive(Clone, Debug)] -pub struct BigUintTarget { - pub limbs: Vec, -} - -impl BigUintTarget { - pub fn num_limbs(&self) -> usize { - self.limbs.len() - } - - pub fn get_limb(&self, i: usize) -> U32Target { - self.limbs[i] - } -} - -pub trait CircuitBuilderBiguint, const D: usize> { - fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget; - - fn zero_biguint(&mut self) -> BigUintTarget; - - fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget); - - fn pad_biguints( - &mut self, - a: &BigUintTarget, - b: &BigUintTarget, - ) -> (BigUintTarget, BigUintTarget); - - fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget; - - fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget; - - /// Add two `BigUintTarget`s. - fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; - - /// Subtract two `BigUintTarget`s. We assume that the first is larger than the second. - fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; - - fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; - - fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget; - - /// Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). - fn mul_add_biguint( - &mut self, - x: &BigUintTarget, - y: &BigUintTarget, - z: &BigUintTarget, - ) -> BigUintTarget; - - fn div_rem_biguint( - &mut self, - a: &BigUintTarget, - b: &BigUintTarget, - ) -> (BigUintTarget, BigUintTarget); - - fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; - - fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; -} - -impl, const D: usize> CircuitBuilderBiguint - for CircuitBuilder -{ - fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget { - let limb_values = value.to_u32_digits(); - let limbs = limb_values.iter().map(|&l| self.constant_u32(l)).collect(); - - BigUintTarget { limbs } - } - - fn zero_biguint(&mut self) -> BigUintTarget { - self.constant_biguint(&BigUint::zero()) - } - - fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) { - let min_limbs = lhs.num_limbs().min(rhs.num_limbs()); - for i in 0..min_limbs { - self.connect_u32(lhs.get_limb(i), rhs.get_limb(i)); - } - - for i in min_limbs..lhs.num_limbs() { - self.assert_zero_u32(lhs.get_limb(i)); - } - for i in min_limbs..rhs.num_limbs() { - self.assert_zero_u32(rhs.get_limb(i)); - } - } - - fn pad_biguints( - &mut self, - a: &BigUintTarget, - b: &BigUintTarget, - ) -> (BigUintTarget, BigUintTarget) { - if a.num_limbs() > b.num_limbs() { - let mut padded_b = b.clone(); - for _ in b.num_limbs()..a.num_limbs() { - padded_b.limbs.push(self.zero_u32()); - } - - (a.clone(), padded_b) - } else { - let mut padded_a = a.clone(); - for _ in a.num_limbs()..b.num_limbs() { - padded_a.limbs.push(self.zero_u32()); - } - - (padded_a, b.clone()) - } - } - - fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget { - let (a, b) = self.pad_biguints(a, b); - - list_le_u32_circuit(self, a.limbs, b.limbs) - } - - fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { - let limbs = self.add_virtual_u32_targets(num_limbs); - - BigUintTarget { limbs } - } - - fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { - let num_limbs = a.num_limbs().max(b.num_limbs()); - - let mut combined_limbs = vec![]; - let mut carry = self.zero_u32(); - for i in 0..num_limbs { - let a_limb = (i < a.num_limbs()) - .then(|| a.limbs[i]) - .unwrap_or_else(|| self.zero_u32()); - let b_limb = (i < b.num_limbs()) - .then(|| b.limbs[i]) - .unwrap_or_else(|| self.zero_u32()); - - let (new_limb, new_carry) = self.add_many_u32(&[carry, a_limb, b_limb]); - carry = new_carry; - combined_limbs.push(new_limb); - } - combined_limbs.push(carry); - - BigUintTarget { - limbs: combined_limbs, - } - } - - fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { - let (a, b) = self.pad_biguints(a, b); - let num_limbs = a.limbs.len(); - - let mut result_limbs = vec![]; - - let mut borrow = self.zero_u32(); - for i in 0..num_limbs { - let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow); - result_limbs.push(result); - borrow = new_borrow; - } - // Borrow should be zero here. - - BigUintTarget { - limbs: result_limbs, - } - } - - fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { - let total_limbs = a.limbs.len() + b.limbs.len(); - - let mut to_add = vec![vec![]; total_limbs]; - for i in 0..a.limbs.len() { - for j in 0..b.limbs.len() { - let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]); - to_add[i + j].push(product); - to_add[i + j + 1].push(carry); - } - } - - let mut combined_limbs = vec![]; - let mut carry = self.zero_u32(); - for summands in &mut to_add { - let (new_result, new_carry) = self.add_u32s_with_carry(summands, carry); - combined_limbs.push(new_result); - carry = new_carry; - } - combined_limbs.push(carry); - - BigUintTarget { - limbs: combined_limbs, - } - } - - fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget { - let t = b.target; - - BigUintTarget { - limbs: a - .limbs - .iter() - .map(|&l| U32Target(self.mul(l.0, t))) - .collect(), - } - } - - fn mul_add_biguint( - &mut self, - x: &BigUintTarget, - y: &BigUintTarget, - z: &BigUintTarget, - ) -> BigUintTarget { - let prod = self.mul_biguint(x, y); - self.add_biguint(&prod, z) - } - - fn div_rem_biguint( - &mut self, - a: &BigUintTarget, - b: &BigUintTarget, - ) -> (BigUintTarget, BigUintTarget) { - let a_len = a.limbs.len(); - let b_len = b.limbs.len(); - let div_num_limbs = if b_len > a_len + 1 { - 0 - } else { - a_len - b_len + 1 - }; - let div = self.add_virtual_biguint_target(div_num_limbs); - let rem = self.add_virtual_biguint_target(b_len); - - self.add_simple_generator(BigUintDivRemGenerator:: { - a: a.clone(), - b: b.clone(), - div: div.clone(), - rem: rem.clone(), - _phantom: PhantomData, - }); - - let div_b = self.mul_biguint(&div, b); - let div_b_plus_rem = self.add_biguint(&div_b, &rem); - self.connect_biguint(a, &div_b_plus_rem); - - let cmp_rem_b = self.cmp_biguint(&rem, b); - self.assert_one(cmp_rem_b.target); - - (div, rem) - } - - fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { - let (div, _rem) = self.div_rem_biguint(a, b); - div - } - - fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { - let (_div, rem) = self.div_rem_biguint(a, b); - rem - } -} - -pub trait WitnessBigUint: Witness { - fn get_biguint_target(&self, target: BigUintTarget) -> BigUint; - fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint); -} - -impl, F: PrimeField64> WitnessBigUint for T { - fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { - target - .limbs - .into_iter() - .rev() - .fold(BigUint::zero(), |acc, limb| { - (acc << 32) + self.get_target(limb.0).to_canonical_biguint() - }) - } - - fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { - let mut limbs = value.to_u32_digits(); - assert!(target.num_limbs() >= limbs.len()); - limbs.resize(target.num_limbs(), 0); - for i in 0..target.num_limbs() { - self.set_u32_target(target.limbs[i], limbs[i]); - } - } -} - -pub trait GeneratedValuesBigUint { - fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint); -} - -impl GeneratedValuesBigUint for GeneratedValues { - fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { - let mut limbs = value.to_u32_digits(); - assert!(target.num_limbs() >= limbs.len()); - limbs.resize(target.num_limbs(), 0); - for i in 0..target.num_limbs() { - self.set_u32_target(target.get_limb(i), limbs[i]); - } - } -} - -#[derive(Debug)] -struct BigUintDivRemGenerator, const D: usize> { - a: BigUintTarget, - b: BigUintTarget, - div: BigUintTarget, - rem: BigUintTarget, - _phantom: PhantomData, -} - -impl, const D: usize> SimpleGenerator - for BigUintDivRemGenerator -{ - fn dependencies(&self) -> Vec { - self.a - .limbs - .iter() - .chain(&self.b.limbs) - .map(|&l| l.0) - .collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = witness.get_biguint_target(self.a.clone()); - let b = witness.get_biguint_target(self.b.clone()); - let (div, rem) = a.div_rem(&b); - - out_buffer.set_biguint_target(&self.div, &div); - out_buffer.set_biguint_target(&self.rem, &rem); - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use num::{BigUint, FromPrimitive, Integer}; - use plonky2::iop::witness::PartialWitness; - use plonky2::plonk::circuit_builder::CircuitBuilder; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use rand::rngs::OsRng; - use rand::Rng; - - use crate::gadgets::biguint::{CircuitBuilderBiguint, WitnessBigUint}; - - #[test] - fn test_biguint_add() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let mut rng = OsRng; - - let x_value = BigUint::from_u128(rng.gen()).unwrap(); - let y_value = BigUint::from_u128(rng.gen()).unwrap(); - let expected_z_value = &x_value + &y_value; - - let config = CircuitConfig::standard_recursion_config(); - let mut pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len()); - let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len()); - let z = builder.add_biguint(&x, &y); - let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); - builder.connect_biguint(&z, &expected_z); - - pw.set_biguint_target(&x, &x_value); - pw.set_biguint_target(&y, &y_value); - pw.set_biguint_target(&expected_z, &expected_z_value); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof) - } - - #[test] - fn test_biguint_sub() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let mut rng = OsRng; - - let mut x_value = BigUint::from_u128(rng.gen()).unwrap(); - let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); - if y_value > x_value { - (x_value, y_value) = (y_value, x_value); - } - let expected_z_value = &x_value - &y_value; - - let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_biguint(&x_value); - let y = builder.constant_biguint(&y_value); - let z = builder.sub_biguint(&x, &y); - let expected_z = builder.constant_biguint(&expected_z_value); - - builder.connect_biguint(&z, &expected_z); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof) - } - - #[test] - fn test_biguint_mul() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let mut rng = OsRng; - - let x_value = BigUint::from_u128(rng.gen()).unwrap(); - let y_value = BigUint::from_u128(rng.gen()).unwrap(); - let expected_z_value = &x_value * &y_value; - - let config = CircuitConfig::standard_recursion_config(); - let mut pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len()); - let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len()); - let z = builder.mul_biguint(&x, &y); - let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); - builder.connect_biguint(&z, &expected_z); - - pw.set_biguint_target(&x, &x_value); - pw.set_biguint_target(&y, &y_value); - pw.set_biguint_target(&expected_z, &expected_z_value); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof) - } - - #[test] - fn test_biguint_cmp() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let mut rng = OsRng; - - let x_value = BigUint::from_u128(rng.gen()).unwrap(); - let y_value = BigUint::from_u128(rng.gen()).unwrap(); - - let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_biguint(&x_value); - let y = builder.constant_biguint(&y_value); - let cmp = builder.cmp_biguint(&x, &y); - let expected_cmp = builder.constant_bool(x_value <= y_value); - - builder.connect(cmp.target, expected_cmp.target); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof) - } - - #[test] - fn test_biguint_div_rem() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let mut rng = OsRng; - - let mut x_value = BigUint::from_u128(rng.gen()).unwrap(); - let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); - if y_value > x_value { - (x_value, y_value) = (y_value, x_value); - } - let (expected_div_value, expected_rem_value) = x_value.div_rem(&y_value); - - let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_biguint(&x_value); - let y = builder.constant_biguint(&y_value); - let (div, rem) = builder.div_rem_biguint(&x, &y); - - let expected_div = builder.constant_biguint(&expected_div_value); - let expected_rem = builder.constant_biguint(&expected_rem_value); - - builder.connect_biguint(&div, &expected_div); - builder.connect_biguint(&rem, &expected_rem); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof) - } -} diff --git a/ecdsa/src/gadgets/curve.rs b/ecdsa/src/gadgets/curve.rs deleted file mode 100644 index 11075322..00000000 --- a/ecdsa/src/gadgets/curve.rs +++ /dev/null @@ -1,486 +0,0 @@ -use alloc::vec; -use alloc::vec::Vec; - -use plonky2::field::extension::Extendable; -use plonky2::field::types::Sample; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::target::BoolTarget; -use plonky2::plonk::circuit_builder::CircuitBuilder; - -use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; -use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; - -/// A Target representing an affine point on the curve `C`. We use incomplete arithmetic for efficiency, -/// so we assume these points are not zero. -#[derive(Clone, Debug)] -pub struct AffinePointTarget { - pub x: NonNativeTarget, - pub y: NonNativeTarget, -} - -impl AffinePointTarget { - pub fn to_vec(&self) -> Vec> { - vec![self.x.clone(), self.y.clone()] - } -} - -pub trait CircuitBuilderCurve, const D: usize> { - fn constant_affine_point(&mut self, point: AffinePoint) -> AffinePointTarget; - - fn connect_affine_point( - &mut self, - lhs: &AffinePointTarget, - rhs: &AffinePointTarget, - ); - - fn add_virtual_affine_point_target(&mut self) -> AffinePointTarget; - - fn curve_assert_valid(&mut self, p: &AffinePointTarget); - - fn curve_neg(&mut self, p: &AffinePointTarget) -> AffinePointTarget; - - fn curve_conditional_neg( - &mut self, - p: &AffinePointTarget, - b: BoolTarget, - ) -> AffinePointTarget; - - fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget; - - fn curve_repeated_double( - &mut self, - p: &AffinePointTarget, - n: usize, - ) -> AffinePointTarget; - - /// Add two points, which are assumed to be non-equal. - fn curve_add( - &mut self, - p1: &AffinePointTarget, - p2: &AffinePointTarget, - ) -> AffinePointTarget; - - fn curve_conditional_add( - &mut self, - p1: &AffinePointTarget, - p2: &AffinePointTarget, - b: BoolTarget, - ) -> AffinePointTarget; - - fn curve_scalar_mul( - &mut self, - p: &AffinePointTarget, - n: &NonNativeTarget, - ) -> AffinePointTarget; -} - -impl, const D: usize> CircuitBuilderCurve - for CircuitBuilder -{ - fn constant_affine_point(&mut self, point: AffinePoint) -> AffinePointTarget { - debug_assert!(!point.zero); - AffinePointTarget { - x: self.constant_nonnative(point.x), - y: self.constant_nonnative(point.y), - } - } - - fn connect_affine_point( - &mut self, - lhs: &AffinePointTarget, - rhs: &AffinePointTarget, - ) { - self.connect_nonnative(&lhs.x, &rhs.x); - self.connect_nonnative(&lhs.y, &rhs.y); - } - - fn add_virtual_affine_point_target(&mut self) -> AffinePointTarget { - let x = self.add_virtual_nonnative_target(); - let y = self.add_virtual_nonnative_target(); - - AffinePointTarget { x, y } - } - - fn curve_assert_valid(&mut self, p: &AffinePointTarget) { - let a = self.constant_nonnative(C::A); - let b = self.constant_nonnative(C::B); - - let y_squared = self.mul_nonnative(&p.y, &p.y); - let x_squared = self.mul_nonnative(&p.x, &p.x); - let x_cubed = self.mul_nonnative(&x_squared, &p.x); - let a_x = self.mul_nonnative(&a, &p.x); - let a_x_plus_b = self.add_nonnative(&a_x, &b); - let rhs = self.add_nonnative(&x_cubed, &a_x_plus_b); - - self.connect_nonnative(&y_squared, &rhs); - } - - fn curve_neg(&mut self, p: &AffinePointTarget) -> AffinePointTarget { - let neg_y = self.neg_nonnative(&p.y); - AffinePointTarget { - x: p.x.clone(), - y: neg_y, - } - } - - fn curve_conditional_neg( - &mut self, - p: &AffinePointTarget, - b: BoolTarget, - ) -> AffinePointTarget { - AffinePointTarget { - x: p.x.clone(), - y: self.nonnative_conditional_neg(&p.y, b), - } - } - - fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget { - let AffinePointTarget { x, y } = p; - let double_y = self.add_nonnative(y, y); - let inv_double_y = self.inv_nonnative(&double_y); - let x_squared = self.mul_nonnative(x, x); - let double_x_squared = self.add_nonnative(&x_squared, &x_squared); - let triple_x_squared = self.add_nonnative(&double_x_squared, &x_squared); - - let a = self.constant_nonnative(C::A); - let triple_xx_a = self.add_nonnative(&triple_x_squared, &a); - let lambda = self.mul_nonnative(&triple_xx_a, &inv_double_y); - let lambda_squared = self.mul_nonnative(&lambda, &lambda); - let x_double = self.add_nonnative(x, x); - - let x3 = self.sub_nonnative(&lambda_squared, &x_double); - - let x_diff = self.sub_nonnative(x, &x3); - let lambda_x_diff = self.mul_nonnative(&lambda, &x_diff); - - let y3 = self.sub_nonnative(&lambda_x_diff, y); - - AffinePointTarget { x: x3, y: y3 } - } - - fn curve_repeated_double( - &mut self, - p: &AffinePointTarget, - n: usize, - ) -> AffinePointTarget { - let mut result = p.clone(); - - for _ in 0..n { - result = self.curve_double(&result); - } - - result - } - - fn curve_add( - &mut self, - p1: &AffinePointTarget, - p2: &AffinePointTarget, - ) -> AffinePointTarget { - let AffinePointTarget { x: x1, y: y1 } = p1; - let AffinePointTarget { x: x2, y: y2 } = p2; - - let u = self.sub_nonnative(y2, y1); - let v = self.sub_nonnative(x2, x1); - let v_inv = self.inv_nonnative(&v); - let s = self.mul_nonnative(&u, &v_inv); - let s_squared = self.mul_nonnative(&s, &s); - let x_sum = self.add_nonnative(x2, x1); - let x3 = self.sub_nonnative(&s_squared, &x_sum); - let x_diff = self.sub_nonnative(x1, &x3); - let prod = self.mul_nonnative(&s, &x_diff); - let y3 = self.sub_nonnative(&prod, y1); - - AffinePointTarget { x: x3, y: y3 } - } - - fn curve_conditional_add( - &mut self, - p1: &AffinePointTarget, - p2: &AffinePointTarget, - b: BoolTarget, - ) -> AffinePointTarget { - let not_b = self.not(b); - let sum = self.curve_add(p1, p2); - let x_if_true = self.mul_nonnative_by_bool(&sum.x, b); - let y_if_true = self.mul_nonnative_by_bool(&sum.y, b); - let x_if_false = self.mul_nonnative_by_bool(&p1.x, not_b); - let y_if_false = self.mul_nonnative_by_bool(&p1.y, not_b); - - let x = self.add_nonnative(&x_if_true, &x_if_false); - let y = self.add_nonnative(&y_if_true, &y_if_false); - - AffinePointTarget { x, y } - } - - fn curve_scalar_mul( - &mut self, - p: &AffinePointTarget, - n: &NonNativeTarget, - ) -> AffinePointTarget { - let bits = self.split_nonnative_to_bits(n); - - let rando = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); - let randot = self.constant_affine_point(rando); - // Result starts at `rando`, which is later subtracted, because we don't support arithmetic with the zero point. - let mut result = self.add_virtual_affine_point_target(); - self.connect_affine_point(&randot, &result); - - let mut two_i_times_p = self.add_virtual_affine_point_target(); - self.connect_affine_point(p, &two_i_times_p); - - for &bit in bits.iter() { - let not_bit = self.not(bit); - - let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p); - - let new_x_if_bit = self.mul_nonnative_by_bool(&result_plus_2_i_p.x, bit); - let new_x_if_not_bit = self.mul_nonnative_by_bool(&result.x, not_bit); - let new_y_if_bit = self.mul_nonnative_by_bool(&result_plus_2_i_p.y, bit); - let new_y_if_not_bit = self.mul_nonnative_by_bool(&result.y, not_bit); - - let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit); - let new_y = self.add_nonnative(&new_y_if_bit, &new_y_if_not_bit); - - result = AffinePointTarget { x: new_x, y: new_y }; - - two_i_times_p = self.curve_double(&two_i_times_p); - } - - // Subtract off result's intial value of `rando`. - let neg_r = self.curve_neg(&randot); - result = self.curve_add(&result, &neg_r); - - result - } -} - -#[cfg(test)] -mod tests { - use core::ops::Neg; - - use anyhow::Result; - use plonky2::field::secp256k1_base::Secp256K1Base; - use plonky2::field::secp256k1_scalar::Secp256K1Scalar; - use plonky2::field::types::{Field, Sample}; - use plonky2::iop::witness::PartialWitness; - use plonky2::plonk::circuit_builder::CircuitBuilder; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - - use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; - use crate::curve::secp256k1::Secp256K1; - use crate::gadgets::curve::CircuitBuilderCurve; - use crate::gadgets::nonnative::CircuitBuilderNonNative; - - #[test] - fn test_curve_point_is_valid() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = Secp256K1::GENERATOR_AFFINE; - let g_target = builder.constant_affine_point(g); - let neg_g_target = builder.curve_neg(&g_target); - - builder.curve_assert_valid(&g_target); - builder.curve_assert_valid(&neg_g_target); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - data.verify(proof) - } - - #[test] - #[should_panic] - fn test_curve_point_is_not_valid() { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = Secp256K1::GENERATOR_AFFINE; - let not_g = AffinePoint:: { - x: g.x, - y: g.y + Secp256K1Base::ONE, - zero: g.zero, - }; - let not_g_target = builder.constant_affine_point(not_g); - - builder.curve_assert_valid(¬_g_target); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - data.verify(proof).unwrap() - } - - #[test] - fn test_curve_double() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = Secp256K1::GENERATOR_AFFINE; - let g_target = builder.constant_affine_point(g); - let neg_g_target = builder.curve_neg(&g_target); - - let double_g = g.double(); - let double_g_expected = builder.constant_affine_point(double_g); - builder.curve_assert_valid(&double_g_expected); - - let double_neg_g = (-g).double(); - let double_neg_g_expected = builder.constant_affine_point(double_neg_g); - builder.curve_assert_valid(&double_neg_g_expected); - - let double_g_actual = builder.curve_double(&g_target); - let double_neg_g_actual = builder.curve_double(&neg_g_target); - builder.curve_assert_valid(&double_g_actual); - builder.curve_assert_valid(&double_neg_g_actual); - - builder.connect_affine_point(&double_g_expected, &double_g_actual); - builder.connect_affine_point(&double_neg_g_expected, &double_neg_g_actual); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - data.verify(proof) - } - - #[test] - fn test_curve_add() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = Secp256K1::GENERATOR_AFFINE; - let double_g = g.double(); - let g_plus_2g = (g + double_g).to_affine(); - let g_plus_2g_expected = builder.constant_affine_point(g_plus_2g); - builder.curve_assert_valid(&g_plus_2g_expected); - - let g_target = builder.constant_affine_point(g); - let double_g_target = builder.curve_double(&g_target); - let g_plus_2g_actual = builder.curve_add(&g_target, &double_g_target); - builder.curve_assert_valid(&g_plus_2g_actual); - - builder.connect_affine_point(&g_plus_2g_expected, &g_plus_2g_actual); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - data.verify(proof) - } - - #[test] - fn test_curve_conditional_add() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = Secp256K1::GENERATOR_AFFINE; - let double_g = g.double(); - let g_plus_2g = (g + double_g).to_affine(); - let g_plus_2g_expected = builder.constant_affine_point(g_plus_2g); - - let g_expected = builder.constant_affine_point(g); - let double_g_target = builder.curve_double(&g_expected); - let t = builder._true(); - let f = builder._false(); - let g_plus_2g_actual = builder.curve_conditional_add(&g_expected, &double_g_target, t); - let g_actual = builder.curve_conditional_add(&g_expected, &double_g_target, f); - - builder.connect_affine_point(&g_plus_2g_expected, &g_plus_2g_actual); - builder.connect_affine_point(&g_expected, &g_actual); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - data.verify(proof) - } - - #[test] - #[ignore] - fn test_curve_mul() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = Secp256K1::GENERATOR_PROJECTIVE.to_affine(); - let five = Secp256K1Scalar::from_canonical_usize(5); - let neg_five = five.neg(); - let neg_five_scalar = CurveScalar::(neg_five); - let neg_five_g = (neg_five_scalar * g.to_projective()).to_affine(); - let neg_five_g_expected = builder.constant_affine_point(neg_five_g); - builder.curve_assert_valid(&neg_five_g_expected); - - let g_target = builder.constant_affine_point(g); - let neg_five_target = builder.constant_nonnative(neg_five); - let neg_five_g_actual = builder.curve_scalar_mul(&g_target, &neg_five_target); - builder.curve_assert_valid(&neg_five_g_actual); - - builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - data.verify(proof) - } - - #[test] - #[ignore] - fn test_curve_random() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let rando = - (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); - let randot = builder.constant_affine_point(rando); - - let two_target = builder.constant_nonnative(Secp256K1Scalar::TWO); - let randot_doubled = builder.curve_double(&randot); - let randot_times_two = builder.curve_scalar_mul(&randot, &two_target); - builder.connect_affine_point(&randot_doubled, &randot_times_two); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - data.verify(proof) - } -} diff --git a/ecdsa/src/gadgets/curve_fixed_base.rs b/ecdsa/src/gadgets/curve_fixed_base.rs deleted file mode 100644 index e7656f5c..00000000 --- a/ecdsa/src/gadgets/curve_fixed_base.rs +++ /dev/null @@ -1,118 +0,0 @@ -use alloc::vec::Vec; - -use num::BigUint; -use plonky2::field::extension::Extendable; -use plonky2::field::types::Field; -use plonky2::hash::hash_types::RichField; -use plonky2::hash::keccak::KeccakHash; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::config::{GenericHashOut, Hasher}; - -use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; -use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; -use crate::gadgets::curve_windowed_mul::CircuitBuilderWindowedMul; -use crate::gadgets::nonnative::NonNativeTarget; -use crate::gadgets::split_nonnative::CircuitBuilderSplit; - -/// Compute windowed fixed-base scalar multiplication, using a 4-bit window. -pub fn fixed_base_curve_mul_circuit, const D: usize>( - builder: &mut CircuitBuilder, - base: AffinePoint, - scalar: &NonNativeTarget, -) -> AffinePointTarget { - // Holds `(16^i) * base` for `i=0..scalar.value.limbs.len() * 8`. - let scaled_base = (0..scalar.value.limbs.len() * 8).scan(base, |acc, _| { - let tmp = *acc; - for _ in 0..4 { - *acc = acc.double(); - } - Some(tmp) - }); - - let limbs = builder.split_nonnative_to_4_bit_limbs(scalar); - - let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); - let hash_0_scalar = C::ScalarField::from_noncanonical_biguint(BigUint::from_bytes_le( - &GenericHashOut::::to_bytes(&hash_0), - )); - let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); - - let zero = builder.zero(); - let mut result = builder.constant_affine_point(rando); - // `s * P = sum s_i * P_i` with `P_i = (16^i) * P` and `s = sum s_i * (16^i)`. - for (limb, point) in limbs.into_iter().zip(scaled_base) { - // `muls_point[t] = t * P_i` for `t=0..16`. - let mut muls_point = (0..16) - .scan(AffinePoint::ZERO, |acc, _| { - let tmp = *acc; - *acc = (point + *acc).to_affine(); - Some(tmp) - }) - // First element if zero, so we skip it since `constant_affine_point` takes non-zero input. - .skip(1) - .map(|p| builder.constant_affine_point(p)) - .collect::>(); - // We add back a point in position 0. `limb == zero` is checked below, so this point can be arbitrary. - muls_point.insert(0, muls_point[0].clone()); - let is_zero = builder.is_equal(limb, zero); - let should_add = builder.not(is_zero); - // `r = s_i * P_i` - let r = builder.random_access_curve_points(limb, muls_point); - result = builder.curve_conditional_add(&result, &r, should_add); - } - - let to_add = builder.constant_affine_point(-rando); - builder.curve_add(&result, &to_add) -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::field::secp256k1_scalar::Secp256K1Scalar; - use plonky2::field::types::{PrimeField, Sample}; - use plonky2::iop::witness::PartialWitness; - use plonky2::plonk::circuit_builder::CircuitBuilder; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - - use crate::curve::curve_types::{Curve, CurveScalar}; - use crate::curve::secp256k1::Secp256K1; - use crate::gadgets::biguint::WitnessBigUint; - use crate::gadgets::curve::CircuitBuilderCurve; - use crate::gadgets::curve_fixed_base::fixed_base_curve_mul_circuit; - use crate::gadgets::nonnative::CircuitBuilderNonNative; - - #[test] - #[ignore] - fn test_fixed_base() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let mut pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = Secp256K1::GENERATOR_AFFINE; - let n = Secp256K1Scalar::rand(); - - let res = (CurveScalar(n) * g.to_projective()).to_affine(); - let res_expected = builder.constant_affine_point(res); - builder.curve_assert_valid(&res_expected); - - let n_target = builder.add_virtual_nonnative_target::(); - pw.set_biguint_target(&n_target.value, &n.to_canonical_biguint()); - - let res_target = fixed_base_curve_mul_circuit(&mut builder, g, &n_target); - builder.curve_assert_valid(&res_target); - - builder.connect_affine_point(&res_target, &res_expected); - - dbg!(builder.num_gates()); - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - data.verify(proof) - } -} diff --git a/ecdsa/src/gadgets/curve_msm.rs b/ecdsa/src/gadgets/curve_msm.rs deleted file mode 100644 index 7bb4a6cc..00000000 --- a/ecdsa/src/gadgets/curve_msm.rs +++ /dev/null @@ -1,138 +0,0 @@ -use alloc::vec; - -use num::BigUint; -use plonky2::field::extension::Extendable; -use plonky2::field::types::Field; -use plonky2::hash::hash_types::RichField; -use plonky2::hash::keccak::KeccakHash; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::config::{GenericHashOut, Hasher}; - -use crate::curve::curve_types::{Curve, CurveScalar}; -use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; -use crate::gadgets::curve_windowed_mul::CircuitBuilderWindowedMul; -use crate::gadgets::nonnative::NonNativeTarget; -use crate::gadgets::split_nonnative::CircuitBuilderSplit; - -/// Computes `n*p + m*q` using windowed MSM, with a 2-bit window. -/// See Algorithm 9.23 in Handbook of Elliptic and Hyperelliptic Curve Cryptography for a -/// description. -/// Note: Doesn't work if `p == q`. -pub fn curve_msm_circuit, const D: usize>( - builder: &mut CircuitBuilder, - p: &AffinePointTarget, - q: &AffinePointTarget, - n: &NonNativeTarget, - m: &NonNativeTarget, -) -> AffinePointTarget { - let limbs_n = builder.split_nonnative_to_2_bit_limbs(n); - let limbs_m = builder.split_nonnative_to_2_bit_limbs(m); - assert_eq!(limbs_n.len(), limbs_m.len()); - let num_limbs = limbs_n.len(); - - let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); - let hash_0_scalar = C::ScalarField::from_noncanonical_biguint(BigUint::from_bytes_le( - &GenericHashOut::::to_bytes(&hash_0), - )); - let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); - let rando_t = builder.constant_affine_point(rando); - let neg_rando = builder.constant_affine_point(-rando); - - // Precomputes `precomputation[i + 4*j] = i*p + j*q` for `i,j=0..4`. - let mut precomputation = vec![p.clone(); 16]; - let mut cur_p = rando_t.clone(); - let mut cur_q = rando_t.clone(); - for i in 0..4 { - precomputation[i] = cur_p.clone(); - precomputation[4 * i] = cur_q.clone(); - cur_p = builder.curve_add(&cur_p, p); - cur_q = builder.curve_add(&cur_q, q); - } - for i in 1..4 { - precomputation[i] = builder.curve_add(&precomputation[i], &neg_rando); - precomputation[4 * i] = builder.curve_add(&precomputation[4 * i], &neg_rando); - } - for i in 1..4 { - for j in 1..4 { - precomputation[i + 4 * j] = - builder.curve_add(&precomputation[i], &precomputation[4 * j]); - } - } - - let four = builder.constant(F::from_canonical_usize(4)); - - let zero = builder.zero(); - let mut result = rando_t; - for (limb_n, limb_m) in limbs_n.into_iter().zip(limbs_m).rev() { - result = builder.curve_repeated_double(&result, 2); - let index = builder.mul_add(four, limb_m, limb_n); - let r = builder.random_access_curve_points(index, precomputation.clone()); - let is_zero = builder.is_equal(index, zero); - let should_add = builder.not(is_zero); - result = builder.curve_conditional_add(&result, &r, should_add); - } - let starting_point_multiplied = (0..2 * num_limbs).fold(rando, |acc, _| acc.double()); - let to_add = builder.constant_affine_point(-starting_point_multiplied); - result = builder.curve_add(&result, &to_add); - - result -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::field::secp256k1_scalar::Secp256K1Scalar; - use plonky2::field::types::Sample; - use plonky2::iop::witness::PartialWitness; - use plonky2::plonk::circuit_builder::CircuitBuilder; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - - use crate::curve::curve_types::{Curve, CurveScalar}; - use crate::curve::secp256k1::Secp256K1; - use crate::gadgets::curve::CircuitBuilderCurve; - use crate::gadgets::curve_msm::curve_msm_circuit; - use crate::gadgets::nonnative::CircuitBuilderNonNative; - - #[test] - #[ignore] - fn test_curve_msm() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let p = - (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); - let q = - (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); - let n = Secp256K1Scalar::rand(); - let m = Secp256K1Scalar::rand(); - - let res = - (CurveScalar(n) * p.to_projective() + CurveScalar(m) * q.to_projective()).to_affine(); - let res_expected = builder.constant_affine_point(res); - builder.curve_assert_valid(&res_expected); - - let p_target = builder.constant_affine_point(p); - let q_target = builder.constant_affine_point(q); - let n_target = builder.constant_nonnative(n); - let m_target = builder.constant_nonnative(m); - - let res_target = - curve_msm_circuit(&mut builder, &p_target, &q_target, &n_target, &m_target); - builder.curve_assert_valid(&res_target); - - builder.connect_affine_point(&res_target, &res_expected); - - dbg!(builder.num_gates()); - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - data.verify(proof) - } -} diff --git a/ecdsa/src/gadgets/curve_windowed_mul.rs b/ecdsa/src/gadgets/curve_windowed_mul.rs deleted file mode 100644 index 39fad17c..00000000 --- a/ecdsa/src/gadgets/curve_windowed_mul.rs +++ /dev/null @@ -1,254 +0,0 @@ -use alloc::vec; -use alloc::vec::Vec; -use core::marker::PhantomData; - -use num::BigUint; -use plonky2::field::extension::Extendable; -use plonky2::field::types::{Field, Sample}; -use plonky2::hash::hash_types::RichField; -use plonky2::hash::keccak::KeccakHash; -use plonky2::iop::target::{BoolTarget, Target}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::config::{GenericHashOut, Hasher}; -use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; - -use crate::curve::curve_types::{Curve, CurveScalar}; -use crate::gadgets::biguint::BigUintTarget; -use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; -use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; -use crate::gadgets::split_nonnative::CircuitBuilderSplit; - -const WINDOW_SIZE: usize = 4; - -pub trait CircuitBuilderWindowedMul, const D: usize> { - fn precompute_window( - &mut self, - p: &AffinePointTarget, - ) -> Vec>; - - fn random_access_curve_points( - &mut self, - access_index: Target, - v: Vec>, - ) -> AffinePointTarget; - - fn if_affine_point( - &mut self, - b: BoolTarget, - p1: &AffinePointTarget, - p2: &AffinePointTarget, - ) -> AffinePointTarget; - - fn curve_scalar_mul_windowed( - &mut self, - p: &AffinePointTarget, - n: &NonNativeTarget, - ) -> AffinePointTarget; -} - -impl, const D: usize> CircuitBuilderWindowedMul - for CircuitBuilder -{ - fn precompute_window( - &mut self, - p: &AffinePointTarget, - ) -> Vec> { - let g = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); - let neg = { - let mut neg = g; - neg.y = -neg.y; - self.constant_affine_point(neg) - }; - - let mut multiples = vec![self.constant_affine_point(g)]; - for i in 1..1 << WINDOW_SIZE { - multiples.push(self.curve_add(p, &multiples[i - 1])); - } - for i in 1..1 << WINDOW_SIZE { - multiples[i] = self.curve_add(&neg, &multiples[i]); - } - multiples - } - - fn random_access_curve_points( - &mut self, - access_index: Target, - v: Vec>, - ) -> AffinePointTarget { - let num_limbs = C::BaseField::BITS / 32; - let zero = self.zero_u32(); - let x_limbs: Vec> = (0..num_limbs) - .map(|i| { - v.iter() - .map(|p| p.x.value.limbs.get(i).unwrap_or(&zero).0) - .collect() - }) - .collect(); - let y_limbs: Vec> = (0..num_limbs) - .map(|i| { - v.iter() - .map(|p| p.y.value.limbs.get(i).unwrap_or(&zero).0) - .collect() - }) - .collect(); - - let selected_x_limbs: Vec<_> = x_limbs - .iter() - .map(|limbs| U32Target(self.random_access(access_index, limbs.clone()))) - .collect(); - let selected_y_limbs: Vec<_> = y_limbs - .iter() - .map(|limbs| U32Target(self.random_access(access_index, limbs.clone()))) - .collect(); - - let x = NonNativeTarget { - value: BigUintTarget { - limbs: selected_x_limbs, - }, - _phantom: PhantomData, - }; - let y = NonNativeTarget { - value: BigUintTarget { - limbs: selected_y_limbs, - }, - _phantom: PhantomData, - }; - AffinePointTarget { x, y } - } - - fn if_affine_point( - &mut self, - b: BoolTarget, - p1: &AffinePointTarget, - p2: &AffinePointTarget, - ) -> AffinePointTarget { - let new_x = self.if_nonnative(b, &p1.x, &p2.x); - let new_y = self.if_nonnative(b, &p1.y, &p2.y); - AffinePointTarget { x: new_x, y: new_y } - } - - fn curve_scalar_mul_windowed( - &mut self, - p: &AffinePointTarget, - n: &NonNativeTarget, - ) -> AffinePointTarget { - let hash_0 = KeccakHash::<25>::hash_no_pad(&[F::ZERO]); - let hash_0_scalar = C::ScalarField::from_noncanonical_biguint(BigUint::from_bytes_le( - &GenericHashOut::::to_bytes(&hash_0), - )); - let starting_point = CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE; - let starting_point_multiplied = { - let mut cur = starting_point; - for _ in 0..C::ScalarField::BITS { - cur = cur.double(); - } - cur - }; - - let mut result = self.constant_affine_point(starting_point.to_affine()); - - let precomputation = self.precompute_window(p); - let zero = self.zero(); - - let windows = self.split_nonnative_to_4_bit_limbs(n); - for i in (0..windows.len()).rev() { - result = self.curve_repeated_double(&result, WINDOW_SIZE); - let window = windows[i]; - - let to_add = self.random_access_curve_points(window, precomputation.clone()); - let is_zero = self.is_equal(window, zero); - let should_add = self.not(is_zero); - result = self.curve_conditional_add(&result, &to_add, should_add); - } - - let to_subtract = self.constant_affine_point(starting_point_multiplied.to_affine()); - let to_add = self.curve_neg(&to_subtract); - result = self.curve_add(&result, &to_add); - - result - } -} - -#[cfg(test)] -mod tests { - use core::ops::Neg; - - use anyhow::Result; - use plonky2::field::secp256k1_scalar::Secp256K1Scalar; - use plonky2::iop::witness::PartialWitness; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use rand::rngs::OsRng; - use rand::Rng; - - use super::*; - use crate::curve::secp256k1::Secp256K1; - - #[test] - fn test_random_access_curve_points() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let num_points = 16; - let points: Vec<_> = (0..num_points) - .map(|_| { - let g = (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE) - .to_affine(); - builder.constant_affine_point(g) - }) - .collect(); - - let mut rng = OsRng; - let access_index = rng.gen::() % num_points; - - let access_index_target = builder.constant(F::from_canonical_usize(access_index)); - let selected = builder.random_access_curve_points(access_index_target, points.clone()); - let expected = points[access_index].clone(); - builder.connect_affine_point(&selected, &expected); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - data.verify(proof) - } - - #[test] - #[ignore] - fn test_curve_windowed_mul() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = - (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); - let five = Secp256K1Scalar::from_canonical_usize(5); - let neg_five = five.neg(); - let neg_five_scalar = CurveScalar::(neg_five); - let neg_five_g = (neg_five_scalar * g.to_projective()).to_affine(); - let neg_five_g_expected = builder.constant_affine_point(neg_five_g); - builder.curve_assert_valid(&neg_five_g_expected); - - let g_target = builder.constant_affine_point(g); - let neg_five_target = builder.constant_nonnative(neg_five); - let neg_five_g_actual = builder.curve_scalar_mul_windowed(&g_target, &neg_five_target); - builder.curve_assert_valid(&neg_five_g_actual); - - builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - data.verify(proof) - } -} diff --git a/ecdsa/src/gadgets/ecdsa.rs b/ecdsa/src/gadgets/ecdsa.rs deleted file mode 100644 index 657ec492..00000000 --- a/ecdsa/src/gadgets/ecdsa.rs +++ /dev/null @@ -1,111 +0,0 @@ -use core::marker::PhantomData; - -use plonky2::field::extension::Extendable; -use plonky2::field::secp256k1_scalar::Secp256K1Scalar; -use plonky2::hash::hash_types::RichField; -use plonky2::plonk::circuit_builder::CircuitBuilder; - -use crate::curve::curve_types::Curve; -use crate::curve::secp256k1::Secp256K1; -use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; -use crate::gadgets::curve_fixed_base::fixed_base_curve_mul_circuit; -use crate::gadgets::glv::CircuitBuilderGlv; -use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; - -#[derive(Clone, Debug)] -pub struct ECDSASecretKeyTarget(pub NonNativeTarget); - -#[derive(Clone, Debug)] -pub struct ECDSAPublicKeyTarget(pub AffinePointTarget); - -#[derive(Clone, Debug)] -pub struct ECDSASignatureTarget { - pub r: NonNativeTarget, - pub s: NonNativeTarget, -} - -pub fn verify_message_circuit, const D: usize>( - builder: &mut CircuitBuilder, - msg: NonNativeTarget, - sig: ECDSASignatureTarget, - pk: ECDSAPublicKeyTarget, -) { - let ECDSASignatureTarget { r, s } = sig; - - builder.curve_assert_valid(&pk.0); - - let c = builder.inv_nonnative(&s); - let u1 = builder.mul_nonnative(&msg, &c); - let u2 = builder.mul_nonnative(&r, &c); - - let point1 = fixed_base_curve_mul_circuit(builder, Secp256K1::GENERATOR_AFFINE, &u1); - let point2 = builder.glv_mul(&pk.0, &u2); - let point = builder.curve_add(&point1, &point2); - - let x = NonNativeTarget:: { - value: point.x.value, - _phantom: PhantomData, - }; - builder.connect_nonnative(&r, &x); -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::field::types::Sample; - use plonky2::iop::witness::PartialWitness; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - - use super::*; - use crate::curve::curve_types::CurveScalar; - use crate::curve::ecdsa::{sign_message, ECDSAPublicKey, ECDSASecretKey, ECDSASignature}; - - fn test_ecdsa_circuit_with_config(config: CircuitConfig) -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - type Curve = Secp256K1; - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let msg = Secp256K1Scalar::rand(); - let msg_target = builder.constant_nonnative(msg); - - let sk = ECDSASecretKey::(Secp256K1Scalar::rand()); - let pk = ECDSAPublicKey((CurveScalar(sk.0) * Curve::GENERATOR_PROJECTIVE).to_affine()); - - let pk_target = ECDSAPublicKeyTarget(builder.constant_affine_point(pk.0)); - - let sig = sign_message(msg, sk); - - let ECDSASignature { r, s } = sig; - let r_target = builder.constant_nonnative(r); - let s_target = builder.constant_nonnative(s); - let sig_target = ECDSASignatureTarget { - r: r_target, - s: s_target, - }; - - verify_message_circuit(&mut builder, msg_target, sig_target, pk_target); - - dbg!(builder.num_gates()); - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof) - } - - #[test] - #[ignore] - fn test_ecdsa_circuit_narrow() -> Result<()> { - test_ecdsa_circuit_with_config(CircuitConfig::standard_ecc_config()) - } - - #[test] - #[ignore] - fn test_ecdsa_circuit_wide() -> Result<()> { - test_ecdsa_circuit_with_config(CircuitConfig::wide_ecc_config()) - } -} diff --git a/ecdsa/src/gadgets/glv.rs b/ecdsa/src/gadgets/glv.rs deleted file mode 100644 index 8ffa9c8e..00000000 --- a/ecdsa/src/gadgets/glv.rs +++ /dev/null @@ -1,180 +0,0 @@ -use alloc::vec::Vec; -use core::marker::PhantomData; - -use plonky2::field::extension::Extendable; -use plonky2::field::secp256k1_base::Secp256K1Base; -use plonky2::field::secp256k1_scalar::Secp256K1Scalar; -use plonky2::field::types::{Field, PrimeField}; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; -use plonky2::iop::target::{BoolTarget, Target}; -use plonky2::iop::witness::{PartitionWitness, WitnessWrite}; -use plonky2::plonk::circuit_builder::CircuitBuilder; - -use crate::curve::glv::{decompose_secp256k1_scalar, GLV_BETA, GLV_S}; -use crate::curve::secp256k1::Secp256K1; -use crate::gadgets::biguint::{GeneratedValuesBigUint, WitnessBigUint}; -use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; -use crate::gadgets::curve_msm::curve_msm_circuit; -use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; - -pub trait CircuitBuilderGlv, const D: usize> { - fn secp256k1_glv_beta(&mut self) -> NonNativeTarget; - - fn decompose_secp256k1_scalar( - &mut self, - k: &NonNativeTarget, - ) -> ( - NonNativeTarget, - NonNativeTarget, - BoolTarget, - BoolTarget, - ); - - fn glv_mul( - &mut self, - p: &AffinePointTarget, - k: &NonNativeTarget, - ) -> AffinePointTarget; -} - -impl, const D: usize> CircuitBuilderGlv - for CircuitBuilder -{ - fn secp256k1_glv_beta(&mut self) -> NonNativeTarget { - self.constant_nonnative(GLV_BETA) - } - - fn decompose_secp256k1_scalar( - &mut self, - k: &NonNativeTarget, - ) -> ( - NonNativeTarget, - NonNativeTarget, - BoolTarget, - BoolTarget, - ) { - let k1 = self.add_virtual_nonnative_target_sized::(4); - let k2 = self.add_virtual_nonnative_target_sized::(4); - let k1_neg = self.add_virtual_bool_target_unsafe(); - let k2_neg = self.add_virtual_bool_target_unsafe(); - - self.add_simple_generator(GLVDecompositionGenerator:: { - k: k.clone(), - k1: k1.clone(), - k2: k2.clone(), - k1_neg, - k2_neg, - _phantom: PhantomData, - }); - - // Check that `k1_raw + GLV_S * k2_raw == k`. - let k1_raw = self.nonnative_conditional_neg(&k1, k1_neg); - let k2_raw = self.nonnative_conditional_neg(&k2, k2_neg); - let s = self.constant_nonnative(GLV_S); - let mut should_be_k = self.mul_nonnative(&s, &k2_raw); - should_be_k = self.add_nonnative(&should_be_k, &k1_raw); - self.connect_nonnative(&should_be_k, k); - - (k1, k2, k1_neg, k2_neg) - } - - fn glv_mul( - &mut self, - p: &AffinePointTarget, - k: &NonNativeTarget, - ) -> AffinePointTarget { - let (k1, k2, k1_neg, k2_neg) = self.decompose_secp256k1_scalar(k); - - let beta = self.secp256k1_glv_beta(); - let beta_px = self.mul_nonnative(&beta, &p.x); - let sp = AffinePointTarget:: { - x: beta_px, - y: p.y.clone(), - }; - - let p_neg = self.curve_conditional_neg(p, k1_neg); - let sp_neg = self.curve_conditional_neg(&sp, k2_neg); - curve_msm_circuit(self, &p_neg, &sp_neg, &k1, &k2) - } -} - -#[derive(Debug)] -struct GLVDecompositionGenerator, const D: usize> { - k: NonNativeTarget, - k1: NonNativeTarget, - k2: NonNativeTarget, - k1_neg: BoolTarget, - k2_neg: BoolTarget, - _phantom: PhantomData, -} - -impl, const D: usize> SimpleGenerator - for GLVDecompositionGenerator -{ - fn dependencies(&self) -> Vec { - self.k.value.limbs.iter().map(|l| l.0).collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let k = Secp256K1Scalar::from_noncanonical_biguint( - witness.get_biguint_target(self.k.value.clone()), - ); - - let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); - - out_buffer.set_biguint_target(&self.k1.value, &k1.to_canonical_biguint()); - out_buffer.set_biguint_target(&self.k2.value, &k2.to_canonical_biguint()); - out_buffer.set_bool_target(self.k1_neg, k1_neg); - out_buffer.set_bool_target(self.k2_neg, k2_neg); - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::field::secp256k1_scalar::Secp256K1Scalar; - use plonky2::field::types::Sample; - use plonky2::iop::witness::PartialWitness; - use plonky2::plonk::circuit_builder::CircuitBuilder; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - - use crate::curve::curve_types::{Curve, CurveScalar}; - use crate::curve::glv::glv_mul; - use crate::curve::secp256k1::Secp256K1; - use crate::gadgets::curve::CircuitBuilderCurve; - use crate::gadgets::glv::CircuitBuilderGlv; - use crate::gadgets::nonnative::CircuitBuilderNonNative; - - #[test] - #[ignore] - fn test_glv_gadget() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let rando = - (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); - let randot = builder.constant_affine_point(rando); - - let scalar = Secp256K1Scalar::rand(); - let scalar_target = builder.constant_nonnative(scalar); - - let rando_glv_scalar = glv_mul(rando.to_projective(), scalar); - let expected = builder.constant_affine_point(rando_glv_scalar.to_affine()); - let actual = builder.glv_mul(&randot, &scalar_target); - builder.connect_affine_point(&expected, &actual); - - dbg!(builder.num_gates()); - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - data.verify(proof) - } -} diff --git a/ecdsa/src/gadgets/mod.rs b/ecdsa/src/gadgets/mod.rs deleted file mode 100644 index 35b10100..00000000 --- a/ecdsa/src/gadgets/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -pub mod biguint; -pub mod curve; -pub mod curve_fixed_base; -pub mod curve_msm; -pub mod curve_windowed_mul; -pub mod ecdsa; -pub mod glv; -pub mod nonnative; -pub mod split_nonnative; diff --git a/ecdsa/src/gadgets/nonnative.rs b/ecdsa/src/gadgets/nonnative.rs deleted file mode 100644 index f1c8f03b..00000000 --- a/ecdsa/src/gadgets/nonnative.rs +++ /dev/null @@ -1,826 +0,0 @@ -use alloc::vec; -use alloc::vec::Vec; -use core::marker::PhantomData; - -use num::{BigUint, Integer, One, Zero}; -use plonky2::field::extension::Extendable; -use plonky2::field::types::{Field, PrimeField}; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; -use plonky2::iop::target::{BoolTarget, Target}; -use plonky2::iop::witness::{PartitionWitness, WitnessWrite}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::util::ceil_div_usize; -use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; -use plonky2_u32::gadgets::range_check::range_check_u32_circuit; -use plonky2_u32::witness::GeneratedValuesU32; - -use crate::gadgets::biguint::{ - BigUintTarget, CircuitBuilderBiguint, GeneratedValuesBigUint, WitnessBigUint, -}; - -#[derive(Clone, Debug)] -pub struct NonNativeTarget { - pub(crate) value: BigUintTarget, - pub(crate) _phantom: PhantomData, -} - -pub trait CircuitBuilderNonNative, const D: usize> { - fn num_nonnative_limbs() -> usize { - ceil_div_usize(FF::BITS, 32) - } - - fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> NonNativeTarget; - - fn nonnative_to_canonical_biguint( - &mut self, - x: &NonNativeTarget, - ) -> BigUintTarget; - - fn constant_nonnative(&mut self, x: FF) -> NonNativeTarget; - - fn zero_nonnative(&mut self) -> NonNativeTarget; - - // Assert that two NonNativeTarget's, both assumed to be in reduced form, are equal. - fn connect_nonnative( - &mut self, - lhs: &NonNativeTarget, - rhs: &NonNativeTarget, - ); - - fn add_virtual_nonnative_target(&mut self) -> NonNativeTarget; - - fn add_virtual_nonnative_target_sized( - &mut self, - num_limbs: usize, - ) -> NonNativeTarget; - - fn add_nonnative( - &mut self, - a: &NonNativeTarget, - b: &NonNativeTarget, - ) -> NonNativeTarget; - - fn mul_nonnative_by_bool( - &mut self, - a: &NonNativeTarget, - b: BoolTarget, - ) -> NonNativeTarget; - - fn if_nonnative( - &mut self, - b: BoolTarget, - x: &NonNativeTarget, - y: &NonNativeTarget, - ) -> NonNativeTarget; - - fn add_many_nonnative( - &mut self, - to_add: &[NonNativeTarget], - ) -> NonNativeTarget; - - // Subtract two `NonNativeTarget`s. - fn sub_nonnative( - &mut self, - a: &NonNativeTarget, - b: &NonNativeTarget, - ) -> NonNativeTarget; - - fn mul_nonnative( - &mut self, - a: &NonNativeTarget, - b: &NonNativeTarget, - ) -> NonNativeTarget; - - fn mul_many_nonnative( - &mut self, - to_mul: &[NonNativeTarget], - ) -> NonNativeTarget; - - fn neg_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget; - - fn inv_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget; - - /// Returns `x % |FF|` as a `NonNativeTarget`. - fn reduce(&mut self, x: &BigUintTarget) -> NonNativeTarget; - - fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget; - - fn bool_to_nonnative(&mut self, b: &BoolTarget) -> NonNativeTarget; - - // Split a nonnative field element to bits. - fn split_nonnative_to_bits(&mut self, x: &NonNativeTarget) -> Vec; - - fn nonnative_conditional_neg( - &mut self, - x: &NonNativeTarget, - b: BoolTarget, - ) -> NonNativeTarget; -} - -impl, const D: usize> CircuitBuilderNonNative - for CircuitBuilder -{ - fn num_nonnative_limbs() -> usize { - ceil_div_usize(FF::BITS, 32) - } - - fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> NonNativeTarget { - NonNativeTarget { - value: x.clone(), - _phantom: PhantomData, - } - } - - fn nonnative_to_canonical_biguint( - &mut self, - x: &NonNativeTarget, - ) -> BigUintTarget { - x.value.clone() - } - - fn constant_nonnative(&mut self, x: FF) -> NonNativeTarget { - let x_biguint = self.constant_biguint(&x.to_canonical_biguint()); - self.biguint_to_nonnative(&x_biguint) - } - - fn zero_nonnative(&mut self) -> NonNativeTarget { - self.constant_nonnative(FF::ZERO) - } - - // Assert that two NonNativeTarget's, both assumed to be in reduced form, are equal. - fn connect_nonnative( - &mut self, - lhs: &NonNativeTarget, - rhs: &NonNativeTarget, - ) { - self.connect_biguint(&lhs.value, &rhs.value); - } - - fn add_virtual_nonnative_target(&mut self) -> NonNativeTarget { - let num_limbs = Self::num_nonnative_limbs::(); - let value = self.add_virtual_biguint_target(num_limbs); - - NonNativeTarget { - value, - _phantom: PhantomData, - } - } - - fn add_virtual_nonnative_target_sized( - &mut self, - num_limbs: usize, - ) -> NonNativeTarget { - let value = self.add_virtual_biguint_target(num_limbs); - - NonNativeTarget { - value, - _phantom: PhantomData, - } - } - - fn add_nonnative( - &mut self, - a: &NonNativeTarget, - b: &NonNativeTarget, - ) -> NonNativeTarget { - let sum = self.add_virtual_nonnative_target::(); - let overflow = self.add_virtual_bool_target_unsafe(); - - self.add_simple_generator(NonNativeAdditionGenerator:: { - a: a.clone(), - b: b.clone(), - sum: sum.clone(), - overflow, - _phantom: PhantomData, - }); - - let sum_expected = self.add_biguint(&a.value, &b.value); - - let modulus = self.constant_biguint(&FF::order()); - let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow); - let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow); - self.connect_biguint(&sum_expected, &sum_actual); - - // Range-check result. - // TODO: can potentially leave unreduced until necessary (e.g. when connecting values). - let cmp = self.cmp_biguint(&sum.value, &modulus); - let one = self.one(); - self.connect(cmp.target, one); - - sum - } - - fn mul_nonnative_by_bool( - &mut self, - a: &NonNativeTarget, - b: BoolTarget, - ) -> NonNativeTarget { - NonNativeTarget { - value: self.mul_biguint_by_bool(&a.value, b), - _phantom: PhantomData, - } - } - - fn if_nonnative( - &mut self, - b: BoolTarget, - x: &NonNativeTarget, - y: &NonNativeTarget, - ) -> NonNativeTarget { - let not_b = self.not(b); - let maybe_x = self.mul_nonnative_by_bool(x, b); - let maybe_y = self.mul_nonnative_by_bool(y, not_b); - self.add_nonnative(&maybe_x, &maybe_y) - } - - fn add_many_nonnative( - &mut self, - to_add: &[NonNativeTarget], - ) -> NonNativeTarget { - if to_add.len() == 1 { - return to_add[0].clone(); - } - - let sum = self.add_virtual_nonnative_target::(); - let overflow = self.add_virtual_u32_target(); - let summands = to_add.to_vec(); - - self.add_simple_generator(NonNativeMultipleAddsGenerator:: { - summands: summands.clone(), - sum: sum.clone(), - overflow, - _phantom: PhantomData, - }); - - range_check_u32_circuit(self, sum.value.limbs.clone()); - range_check_u32_circuit(self, vec![overflow]); - - let sum_expected = summands - .iter() - .fold(self.zero_biguint(), |a, b| self.add_biguint(&a, &b.value)); - - let modulus = self.constant_biguint(&FF::order()); - let overflow_biguint = BigUintTarget { - limbs: vec![overflow], - }; - let mod_times_overflow = self.mul_biguint(&modulus, &overflow_biguint); - let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow); - self.connect_biguint(&sum_expected, &sum_actual); - - // Range-check result. - // TODO: can potentially leave unreduced until necessary (e.g. when connecting values). - let cmp = self.cmp_biguint(&sum.value, &modulus); - let one = self.one(); - self.connect(cmp.target, one); - - sum - } - - // Subtract two `NonNativeTarget`s. - fn sub_nonnative( - &mut self, - a: &NonNativeTarget, - b: &NonNativeTarget, - ) -> NonNativeTarget { - let diff = self.add_virtual_nonnative_target::(); - let overflow = self.add_virtual_bool_target_unsafe(); - - self.add_simple_generator(NonNativeSubtractionGenerator:: { - a: a.clone(), - b: b.clone(), - diff: diff.clone(), - overflow, - _phantom: PhantomData, - }); - - range_check_u32_circuit(self, diff.value.limbs.clone()); - self.assert_bool(overflow); - - let diff_plus_b = self.add_biguint(&diff.value, &b.value); - let modulus = self.constant_biguint(&FF::order()); - let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow); - let diff_plus_b_reduced = self.sub_biguint(&diff_plus_b, &mod_times_overflow); - self.connect_biguint(&a.value, &diff_plus_b_reduced); - - diff - } - - fn mul_nonnative( - &mut self, - a: &NonNativeTarget, - b: &NonNativeTarget, - ) -> NonNativeTarget { - let prod = self.add_virtual_nonnative_target::(); - let modulus = self.constant_biguint(&FF::order()); - let overflow = self.add_virtual_biguint_target( - a.value.num_limbs() + b.value.num_limbs() - modulus.num_limbs(), - ); - - self.add_simple_generator(NonNativeMultiplicationGenerator:: { - a: a.clone(), - b: b.clone(), - prod: prod.clone(), - overflow: overflow.clone(), - _phantom: PhantomData, - }); - - range_check_u32_circuit(self, prod.value.limbs.clone()); - range_check_u32_circuit(self, overflow.limbs.clone()); - - let prod_expected = self.mul_biguint(&a.value, &b.value); - - let mod_times_overflow = self.mul_biguint(&modulus, &overflow); - let prod_actual = self.add_biguint(&prod.value, &mod_times_overflow); - self.connect_biguint(&prod_expected, &prod_actual); - - prod - } - - fn mul_many_nonnative( - &mut self, - to_mul: &[NonNativeTarget], - ) -> NonNativeTarget { - if to_mul.len() == 1 { - return to_mul[0].clone(); - } - - let mut accumulator = self.mul_nonnative(&to_mul[0], &to_mul[1]); - for t in to_mul.iter().skip(2) { - accumulator = self.mul_nonnative(&accumulator, t); - } - accumulator - } - - fn neg_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { - let zero_target = self.constant_biguint(&BigUint::zero()); - let zero_ff = self.biguint_to_nonnative(&zero_target); - - self.sub_nonnative(&zero_ff, x) - } - - fn inv_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { - let num_limbs = x.value.num_limbs(); - let inv_biguint = self.add_virtual_biguint_target(num_limbs); - let div = self.add_virtual_biguint_target(num_limbs); - - self.add_simple_generator(NonNativeInverseGenerator:: { - x: x.clone(), - inv: inv_biguint.clone(), - div: div.clone(), - _phantom: PhantomData, - }); - - let product = self.mul_biguint(&x.value, &inv_biguint); - - let modulus = self.constant_biguint(&FF::order()); - let mod_times_div = self.mul_biguint(&modulus, &div); - let one = self.constant_biguint(&BigUint::one()); - let expected_product = self.add_biguint(&mod_times_div, &one); - self.connect_biguint(&product, &expected_product); - - NonNativeTarget:: { - value: inv_biguint, - _phantom: PhantomData, - } - } - - /// Returns `x % |FF|` as a `NonNativeTarget`. - fn reduce(&mut self, x: &BigUintTarget) -> NonNativeTarget { - let modulus = FF::order(); - let order_target = self.constant_biguint(&modulus); - let value = self.rem_biguint(x, &order_target); - - NonNativeTarget { - value, - _phantom: PhantomData, - } - } - - fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { - let x_biguint = self.nonnative_to_canonical_biguint(x); - self.reduce(&x_biguint) - } - - fn bool_to_nonnative(&mut self, b: &BoolTarget) -> NonNativeTarget { - let limbs = vec![U32Target(b.target)]; - let value = BigUintTarget { limbs }; - - NonNativeTarget { - value, - _phantom: PhantomData, - } - } - - // Split a nonnative field element to bits. - fn split_nonnative_to_bits(&mut self, x: &NonNativeTarget) -> Vec { - let num_limbs = x.value.num_limbs(); - let mut result = Vec::with_capacity(num_limbs * 32); - - for i in 0..num_limbs { - let limb = x.value.get_limb(i); - let bit_targets = self.split_le_base::<2>(limb.0, 32); - let mut bits: Vec<_> = bit_targets - .iter() - .map(|&t| BoolTarget::new_unsafe(t)) - .collect(); - - result.append(&mut bits); - } - - result - } - - fn nonnative_conditional_neg( - &mut self, - x: &NonNativeTarget, - b: BoolTarget, - ) -> NonNativeTarget { - let not_b = self.not(b); - let neg = self.neg_nonnative(x); - let x_if_true = self.mul_nonnative_by_bool(&neg, b); - let x_if_false = self.mul_nonnative_by_bool(x, not_b); - - self.add_nonnative(&x_if_true, &x_if_false) - } -} - -#[derive(Debug)] -struct NonNativeAdditionGenerator, const D: usize, FF: PrimeField> { - a: NonNativeTarget, - b: NonNativeTarget, - sum: NonNativeTarget, - overflow: BoolTarget, - _phantom: PhantomData, -} - -impl, const D: usize, FF: PrimeField> SimpleGenerator - for NonNativeAdditionGenerator -{ - fn dependencies(&self) -> Vec { - self.a - .value - .limbs - .iter() - .cloned() - .chain(self.b.value.limbs.clone()) - .map(|l| l.0) - .collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = FF::from_noncanonical_biguint(witness.get_biguint_target(self.a.value.clone())); - let b = FF::from_noncanonical_biguint(witness.get_biguint_target(self.b.value.clone())); - let a_biguint = a.to_canonical_biguint(); - let b_biguint = b.to_canonical_biguint(); - let sum_biguint = a_biguint + b_biguint; - let modulus = FF::order(); - let (overflow, sum_reduced) = if sum_biguint > modulus { - (true, sum_biguint - modulus) - } else { - (false, sum_biguint) - }; - - out_buffer.set_biguint_target(&self.sum.value, &sum_reduced); - out_buffer.set_bool_target(self.overflow, overflow); - } -} - -#[derive(Debug)] -struct NonNativeMultipleAddsGenerator, const D: usize, FF: PrimeField> -{ - summands: Vec>, - sum: NonNativeTarget, - overflow: U32Target, - _phantom: PhantomData, -} - -impl, const D: usize, FF: PrimeField> SimpleGenerator - for NonNativeMultipleAddsGenerator -{ - fn dependencies(&self) -> Vec { - self.summands - .iter() - .flat_map(|summand| summand.value.limbs.iter().map(|limb| limb.0)) - .collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let summands: Vec<_> = self - .summands - .iter() - .map(|summand| { - FF::from_noncanonical_biguint(witness.get_biguint_target(summand.value.clone())) - }) - .collect(); - let summand_biguints: Vec<_> = summands - .iter() - .map(|summand| summand.to_canonical_biguint()) - .collect(); - - let sum_biguint = summand_biguints - .iter() - .fold(BigUint::zero(), |a, b| a + b.clone()); - - let modulus = FF::order(); - let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus); - let overflow = overflow_biguint.to_u64_digits()[0] as u32; - - out_buffer.set_biguint_target(&self.sum.value, &sum_reduced); - out_buffer.set_u32_target(self.overflow, overflow); - } -} - -#[derive(Debug)] -struct NonNativeSubtractionGenerator, const D: usize, FF: Field> { - a: NonNativeTarget, - b: NonNativeTarget, - diff: NonNativeTarget, - overflow: BoolTarget, - _phantom: PhantomData, -} - -impl, const D: usize, FF: PrimeField> SimpleGenerator - for NonNativeSubtractionGenerator -{ - fn dependencies(&self) -> Vec { - self.a - .value - .limbs - .iter() - .cloned() - .chain(self.b.value.limbs.clone()) - .map(|l| l.0) - .collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = FF::from_noncanonical_biguint(witness.get_biguint_target(self.a.value.clone())); - let b = FF::from_noncanonical_biguint(witness.get_biguint_target(self.b.value.clone())); - let a_biguint = a.to_canonical_biguint(); - let b_biguint = b.to_canonical_biguint(); - - let modulus = FF::order(); - let (diff_biguint, overflow) = if a_biguint >= b_biguint { - (a_biguint - b_biguint, false) - } else { - (modulus + a_biguint - b_biguint, true) - }; - - out_buffer.set_biguint_target(&self.diff.value, &diff_biguint); - out_buffer.set_bool_target(self.overflow, overflow); - } -} - -#[derive(Debug)] -struct NonNativeMultiplicationGenerator, const D: usize, FF: Field> { - a: NonNativeTarget, - b: NonNativeTarget, - prod: NonNativeTarget, - overflow: BigUintTarget, - _phantom: PhantomData, -} - -impl, const D: usize, FF: PrimeField> SimpleGenerator - for NonNativeMultiplicationGenerator -{ - fn dependencies(&self) -> Vec { - self.a - .value - .limbs - .iter() - .cloned() - .chain(self.b.value.limbs.clone()) - .map(|l| l.0) - .collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = FF::from_noncanonical_biguint(witness.get_biguint_target(self.a.value.clone())); - let b = FF::from_noncanonical_biguint(witness.get_biguint_target(self.b.value.clone())); - let a_biguint = a.to_canonical_biguint(); - let b_biguint = b.to_canonical_biguint(); - - let prod_biguint = a_biguint * b_biguint; - - let modulus = FF::order(); - let (overflow_biguint, prod_reduced) = prod_biguint.div_rem(&modulus); - - out_buffer.set_biguint_target(&self.prod.value, &prod_reduced); - out_buffer.set_biguint_target(&self.overflow, &overflow_biguint); - } -} - -#[derive(Debug)] -struct NonNativeInverseGenerator, const D: usize, FF: PrimeField> { - x: NonNativeTarget, - inv: BigUintTarget, - div: BigUintTarget, - _phantom: PhantomData, -} - -impl, const D: usize, FF: PrimeField> SimpleGenerator - for NonNativeInverseGenerator -{ - fn dependencies(&self) -> Vec { - self.x.value.limbs.iter().map(|&l| l.0).collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let x = FF::from_noncanonical_biguint(witness.get_biguint_target(self.x.value.clone())); - let inv = x.inverse(); - - let x_biguint = x.to_canonical_biguint(); - let inv_biguint = inv.to_canonical_biguint(); - let prod = x_biguint * &inv_biguint; - let modulus = FF::order(); - let (div, _rem) = prod.div_rem(&modulus); - - out_buffer.set_biguint_target(&self.div, &div); - out_buffer.set_biguint_target(&self.inv, &inv_biguint); - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::field::secp256k1_base::Secp256K1Base; - use plonky2::field::types::{Field, PrimeField, Sample}; - use plonky2::iop::witness::PartialWitness; - use plonky2::plonk::circuit_builder::CircuitBuilder; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - - use crate::gadgets::nonnative::CircuitBuilderNonNative; - - #[test] - fn test_nonnative_add() -> Result<()> { - type FF = Secp256K1Base; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let x_ff = FF::rand(); - let y_ff = FF::rand(); - let sum_ff = x_ff + y_ff; - - let config = CircuitConfig::standard_ecc_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_nonnative(x_ff); - let y = builder.constant_nonnative(y_ff); - let sum = builder.add_nonnative(&x, &y); - - let sum_expected = builder.constant_nonnative(sum_ff); - builder.connect_nonnative(&sum, &sum_expected); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof) - } - - #[test] - fn test_nonnative_many_adds() -> Result<()> { - type FF = Secp256K1Base; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let a_ff = FF::rand(); - let b_ff = FF::rand(); - let c_ff = FF::rand(); - let d_ff = FF::rand(); - let e_ff = FF::rand(); - let f_ff = FF::rand(); - let g_ff = FF::rand(); - let h_ff = FF::rand(); - let sum_ff = a_ff + b_ff + c_ff + d_ff + e_ff + f_ff + g_ff + h_ff; - - let config = CircuitConfig::standard_ecc_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let a = builder.constant_nonnative(a_ff); - let b = builder.constant_nonnative(b_ff); - let c = builder.constant_nonnative(c_ff); - let d = builder.constant_nonnative(d_ff); - let e = builder.constant_nonnative(e_ff); - let f = builder.constant_nonnative(f_ff); - let g = builder.constant_nonnative(g_ff); - let h = builder.constant_nonnative(h_ff); - let all = [a, b, c, d, e, f, g, h]; - let sum = builder.add_many_nonnative(&all); - - let sum_expected = builder.constant_nonnative(sum_ff); - builder.connect_nonnative(&sum, &sum_expected); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof) - } - - #[test] - fn test_nonnative_sub() -> Result<()> { - type FF = Secp256K1Base; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let x_ff = FF::rand(); - let mut y_ff = FF::rand(); - while y_ff.to_canonical_biguint() > x_ff.to_canonical_biguint() { - y_ff = FF::rand(); - } - let diff_ff = x_ff - y_ff; - - let config = CircuitConfig::standard_ecc_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_nonnative(x_ff); - let y = builder.constant_nonnative(y_ff); - let diff = builder.sub_nonnative(&x, &y); - - let diff_expected = builder.constant_nonnative(diff_ff); - builder.connect_nonnative(&diff, &diff_expected); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof) - } - - #[test] - fn test_nonnative_mul() -> Result<()> { - type FF = Secp256K1Base; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let x_ff = FF::rand(); - let y_ff = FF::rand(); - let product_ff = x_ff * y_ff; - - let config = CircuitConfig::standard_ecc_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_nonnative(x_ff); - let y = builder.constant_nonnative(y_ff); - let product = builder.mul_nonnative(&x, &y); - - let product_expected = builder.constant_nonnative(product_ff); - builder.connect_nonnative(&product, &product_expected); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof) - } - - #[test] - fn test_nonnative_neg() -> Result<()> { - type FF = Secp256K1Base; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let x_ff = FF::rand(); - let neg_x_ff = -x_ff; - - let config = CircuitConfig::standard_ecc_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_nonnative(x_ff); - let neg_x = builder.neg_nonnative(&x); - - let neg_x_expected = builder.constant_nonnative(neg_x_ff); - builder.connect_nonnative(&neg_x, &neg_x_expected); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof) - } - - #[test] - fn test_nonnative_inv() -> Result<()> { - type FF = Secp256K1Base; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let x_ff = FF::rand(); - let inv_x_ff = x_ff.inverse(); - - let config = CircuitConfig::standard_ecc_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_nonnative(x_ff); - let inv_x = builder.inv_nonnative(&x); - - let inv_x_expected = builder.constant_nonnative(inv_x_ff); - builder.connect_nonnative(&inv_x, &inv_x_expected); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof) - } -} diff --git a/ecdsa/src/gadgets/split_nonnative.rs b/ecdsa/src/gadgets/split_nonnative.rs deleted file mode 100644 index 977912e2..00000000 --- a/ecdsa/src/gadgets/split_nonnative.rs +++ /dev/null @@ -1,131 +0,0 @@ -use alloc::vec::Vec; -use core::marker::PhantomData; - -use itertools::Itertools; -use plonky2::field::extension::Extendable; -use plonky2::field::types::Field; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::target::Target; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; - -use crate::gadgets::biguint::BigUintTarget; -use crate::gadgets::nonnative::NonNativeTarget; - -pub trait CircuitBuilderSplit, const D: usize> { - fn split_u32_to_4_bit_limbs(&mut self, val: U32Target) -> Vec; - - fn split_nonnative_to_4_bit_limbs( - &mut self, - val: &NonNativeTarget, - ) -> Vec; - - fn split_nonnative_to_2_bit_limbs( - &mut self, - val: &NonNativeTarget, - ) -> Vec; - - // Note: assumes its inputs are 4-bit limbs, and does not range-check. - fn recombine_nonnative_4_bit_limbs( - &mut self, - limbs: Vec, - ) -> NonNativeTarget; -} - -impl, const D: usize> CircuitBuilderSplit - for CircuitBuilder -{ - fn split_u32_to_4_bit_limbs(&mut self, val: U32Target) -> Vec { - let two_bit_limbs = self.split_le_base::<4>(val.0, 16); - let four = self.constant(F::from_canonical_usize(4)); - let combined_limbs = two_bit_limbs - .iter() - .tuples() - .map(|(&a, &b)| self.mul_add(b, four, a)) - .collect(); - - combined_limbs - } - - fn split_nonnative_to_4_bit_limbs( - &mut self, - val: &NonNativeTarget, - ) -> Vec { - val.value - .limbs - .iter() - .flat_map(|&l| self.split_u32_to_4_bit_limbs(l)) - .collect() - } - - fn split_nonnative_to_2_bit_limbs( - &mut self, - val: &NonNativeTarget, - ) -> Vec { - val.value - .limbs - .iter() - .flat_map(|&l| self.split_le_base::<4>(l.0, 16)) - .collect() - } - - // Note: assumes its inputs are 4-bit limbs, and does not range-check. - fn recombine_nonnative_4_bit_limbs( - &mut self, - limbs: Vec, - ) -> NonNativeTarget { - let base = self.constant_u32(1 << 4); - let u32_limbs = limbs - .chunks(8) - .map(|chunk| { - let mut combined_chunk = self.zero_u32(); - for i in (0..8).rev() { - let (low, _high) = self.mul_add_u32(combined_chunk, base, U32Target(chunk[i])); - combined_chunk = low; - } - combined_chunk - }) - .collect(); - - NonNativeTarget { - value: BigUintTarget { limbs: u32_limbs }, - _phantom: PhantomData, - } - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::field::secp256k1_scalar::Secp256K1Scalar; - use plonky2::field::types::Sample; - use plonky2::iop::witness::PartialWitness; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - - use super::*; - use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; - - #[test] - fn test_split_nonnative() -> Result<()> { - type FF = Secp256K1Scalar; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = FF::rand(); - let x_target = builder.constant_nonnative(x); - let split = builder.split_nonnative_to_4_bit_limbs(&x_target); - let combined: NonNativeTarget = - builder.recombine_nonnative_4_bit_limbs(split); - builder.connect_nonnative(&x_target, &combined); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof) - } -} diff --git a/ecdsa/src/lib.rs b/ecdsa/src/lib.rs deleted file mode 100644 index bf84913a..00000000 --- a/ecdsa/src/lib.rs +++ /dev/null @@ -1,7 +0,0 @@ -#![allow(clippy::needless_range_loop)] -#![cfg_attr(not(test), no_std)] - -extern crate alloc; - -pub mod curve; -pub mod gadgets; diff --git a/evm/src/bn254_arithmetic.rs b/evm/src/bn254_arithmetic.rs index d1050560..c2f1e3d4 100644 --- a/evm/src/bn254_arithmetic.rs +++ b/evm/src/bn254_arithmetic.rs @@ -27,9 +27,9 @@ impl Fp { impl Distribution for Standard { fn sample(&self, rng: &mut R) -> Fp { - let (x0, x1, x2, x3) = rng.gen::<(u64, u64, u64, u64)>(); + let xs = rng.gen::<[u64; 4]>(); Fp { - val: U256([x0, x1, x2, x3]) % BN_BASE, + val: U256(xs) % BN_BASE, } } } diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index cd494c36..aff40034 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -12,6 +12,7 @@ pub static KERNEL: Lazy = Lazy::new(combined_kernel); pub(crate) fn combined_kernel() -> Kernel { let files = vec![ include_str!("asm/core/bootloader.asm"), + include_str!("asm/core/call.asm"), include_str!("asm/core/create.asm"), include_str!("asm/core/create_addresses.asm"), include_str!("asm/core/intrinsic_gas.asm"), @@ -97,6 +98,7 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/rlp/num_bytes.asm"), include_str!("asm/rlp/read_to_memory.asm"), include_str!("asm/shift.asm"), + include_str!("asm/transactions/common_decoding.asm"), include_str!("asm/transactions/router.asm"), include_str!("asm/transactions/type_0.asm"), include_str!("asm/transactions/type_1.asm"), diff --git a/evm/src/cpu/kernel/asm/account_code.asm b/evm/src/cpu/kernel/asm/account_code.asm index ebe1a5c9..f10fbc19 100644 --- a/evm/src/cpu/kernel/asm/account_code.asm +++ b/evm/src/cpu/kernel/asm/account_code.asm @@ -24,32 +24,49 @@ global extcodehash: %eq_const(@EMPTY_STRING_HASH) %endmacro -%macro codesize - // stack: (empty) - %address - %extcodesize -%endmacro - %macro extcodesize %stack (address) -> (address, 0, @SEGMENT_KERNEL_ACCOUNT_CODE, %%after) %jump(load_code) %%after: %endmacro +global sys_extcodesize: + // stack: kexit_info, address + SWAP1 + // stack: address, kexit_info + %extcodesize + // stack: code_size, kexit_info + SWAP1 + EXIT_KERNEL + global extcodesize: // stack: address, retdest %extcodesize // stack: extcodesize(address), retdest SWAP1 JUMP - %macro codecopy - // stack: dest_offset, offset, size, retdest + // stack: dest_offset, offset, size %address - // stack: address, dest_offset, offset, size, retdest - %jump(extcodecopy) + %extcodecopy %endmacro +%macro extcodecopy + // stack: address, dest_offset, offset, size + %stack (dest_offset, offset, size) -> (dest_offset, offset, size, %%after) + %jump(extcodecopy) +%%after: +%endmacro + +// Pre stack: kexit_info, address, dest_offset, offset, size +// Post stack: (empty) +global sys_extcodecopy: + %stack (kexit_info, address, dest_offset, offset, size) + -> (address, dest_offset, offset, size, kexit_info) + %extcodecopy + // stack: kexit_info + EXIT_KERNEL + // Pre stack: address, dest_offset, offset, size, retdest // Post stack: (empty) global extcodecopy: @@ -59,91 +76,92 @@ global extcodecopy: %jump(load_code) extcodecopy_contd: - // stack: code_length, size, offset, dest_offset, retdest + // stack: code_size, size, offset, dest_offset, retdest SWAP1 - // stack: size, code_length, offset, dest_offset, retdest + // stack: size, code_size, offset, dest_offset, retdest PUSH 0 // Loop copying the `code[offset]` to `memory[dest_offset]` until `i==size`. // Each iteration increments `offset, dest_offset, i`. // TODO: Consider implementing this with memcpy. extcodecopy_loop: - // stack: i, size, code_length, offset, dest_offset, retdest + // stack: i, size, code_size, offset, dest_offset, retdest DUP2 DUP2 EQ - // stack: i == size, i, size, code_length, offset, dest_offset, retdest + // stack: i == size, i, size, code_size, offset, dest_offset, retdest %jumpi(extcodecopy_end) - %stack (i, size, code_length, offset, dest_offset, retdest) - -> (offset, code_length, offset, code_length, dest_offset, i, size, retdest) + %stack (i, size, code_size, offset, dest_offset, retdest) + -> (offset, code_size, offset, code_size, dest_offset, i, size, retdest) LT - // stack: offset < code_length, offset, code_length, dest_offset, i, size, retdest + // stack: offset < code_size, offset, code_size, dest_offset, i, size, retdest DUP2 - // stack: offset, offset < code_length, offset, code_length, dest_offset, i, size, retdest + // stack: offset, offset < code_size, offset, code_size, dest_offset, i, size, retdest %mload_current(@SEGMENT_KERNEL_ACCOUNT_CODE) - // stack: opcode, offset < code_length, offset, code_length, dest_offset, i, size, retdest - %stack (opcode, offset_lt_code_length, offset, code_length, dest_offset, i, size, retdest) - -> (offset_lt_code_length, 0, opcode, offset, code_length, dest_offset, i, size, retdest) - // If `offset >= code_length`, use `opcode=0`. Necessary since `SEGMENT_KERNEL_ACCOUNT_CODE` might be clobbered from previous calls. + // stack: opcode, offset < code_size, offset, code_size, dest_offset, i, size, retdest + %stack (opcode, offset_lt_code_size, offset, code_size, dest_offset, i, size, retdest) + -> (offset_lt_code_size, 0, opcode, offset, code_size, dest_offset, i, size, retdest) + // If `offset >= code_size`, use `opcode=0`. Necessary since `SEGMENT_KERNEL_ACCOUNT_CODE` might be clobbered from previous calls. %select_bool - // stack: opcode, offset, code_length, dest_offset, i, size, retdest + // stack: opcode, offset, code_size, dest_offset, i, size, retdest DUP4 - // stack: dest_offset, opcode, offset, code_length, dest_offset, i, size, retdest + // stack: dest_offset, opcode, offset, code_size, dest_offset, i, size, retdest %mstore_main - // stack: offset, code_length, dest_offset, i, size, retdest + // stack: offset, code_size, dest_offset, i, size, retdest %increment - // stack: offset+1, code_length, dest_offset, i, size, retdest + // stack: offset+1, code_size, dest_offset, i, size, retdest SWAP2 - // stack: dest_offset, code_length, offset+1, i, size, retdest + // stack: dest_offset, code_size, offset+1, i, size, retdest %increment - // stack: dest_offset+1, code_length, offset+1, i, size, retdest + // stack: dest_offset+1, code_size, offset+1, i, size, retdest SWAP3 - // stack: i, code_length, offset+1, dest_offset+1, size, retdest + // stack: i, code_size, offset+1, dest_offset+1, size, retdest %increment - // stack: i+1, code_length, offset+1, dest_offset+1, size, retdest - %stack (i, code_length, offset, dest_offset, size, retdest) -> (i, size, code_length, offset, dest_offset, retdest) + // stack: i+1, code_size, offset+1, dest_offset+1, size, retdest + %stack (i, code_size, offset, dest_offset, size, retdest) -> (i, size, code_size, offset, dest_offset, retdest) %jump(extcodecopy_loop) extcodecopy_end: - %stack (i, size, code_length, offset, dest_offset, retdest) -> (retdest) + %stack (i, size, code_size, offset, dest_offset, retdest) -> (retdest) JUMP // Loads the code at `address` into memory, at the given context and segment, starting at offset 0. // Checks that the hash of the loaded code corresponds to the `codehash` in the state trie. // Pre stack: address, ctx, segment, retdest -// Post stack: code_len +// Post stack: code_size global load_code: %stack (address, ctx, segment, retdest) -> (extcodehash, address, load_code_ctd, ctx, segment, retdest) JUMP load_code_ctd: // stack: codehash, ctx, segment, retdest PROVER_INPUT(account_code::length) - // stack: code_length, codehash, ctx, segment, retdest + // stack: code_size, codehash, ctx, segment, retdest PUSH 0 -// Loop non-deterministically querying `code[i]` and storing it in `SEGMENT_KERNEL_ACCOUNT_CODE` at offset `i`, until `i==code_length`. +// Loop non-deterministically querying `code[i]` and storing it in `SEGMENT_KERNEL_ACCOUNT_CODE` +// at offset `i`, until `i==code_size`. load_code_loop: - // stack: i, code_length, codehash, ctx, segment, retdest + // stack: i, code_size, codehash, ctx, segment, retdest DUP2 DUP2 EQ - // stack: i == code_length, i, code_length, codehash, ctx, segment, retdest + // stack: i == code_size, i, code_size, codehash, ctx, segment, retdest %jumpi(load_code_check) PROVER_INPUT(account_code::get) - // stack: opcode, i, code_length, codehash, ctx, segment, retdest + // stack: opcode, i, code_size, codehash, ctx, segment, retdest DUP2 - // stack: i, opcode, i, code_length, codehash, ctx, segment, retdest + // stack: i, opcode, i, code_size, codehash, ctx, segment, retdest DUP7 // segment DUP7 // context MSTORE_GENERAL - // stack: i, code_length, codehash, ctx, segment, retdest + // stack: i, code_size, codehash, ctx, segment, retdest %increment - // stack: i+1, code_length, codehash, ctx, segment, retdest + // stack: i+1, code_size, codehash, ctx, segment, retdest %jump(load_code_loop) // Check that the hash of the loaded code equals `codehash`. load_code_check: - // stack: i, code_length, codehash, ctx, segment, retdest - %stack (i, code_length, codehash, ctx, segment, retdest) - -> (ctx, segment, 0, code_length, codehash, retdest, code_length) + // stack: i, code_size, codehash, ctx, segment, retdest + %stack (i, code_size, codehash, ctx, segment, retdest) + -> (ctx, segment, 0, code_size, codehash, retdest, code_size) KECCAK_GENERAL - // stack: shouldbecodehash, codehash, retdest, code_length + // stack: shouldbecodehash, codehash, retdest, code_size %assert_eq JUMP diff --git a/evm/src/cpu/kernel/asm/core/call.asm b/evm/src/cpu/kernel/asm/core/call.asm index 761ffc7d..198a6cbb 100644 --- a/evm/src/cpu/kernel/asm/core/call.asm +++ b/evm/src/cpu/kernel/asm/core/call.asm @@ -1,113 +1,178 @@ // Handlers for call-like operations, namely CALL, CALLCODE, STATICCALL and DELEGATECALL. // Creates a new sub context and executes the code of the given account. -global call: - // stack: gas, address, value, args_offset, args_size, ret_offset, ret_size, retdest - %address - %stack (self, gas, address, value) - // These are (static, should_transfer_value, value, sender, address, code_addr, gas) - -> (0, 1, value, self, address, address, gas) - %jump(call_common) +global sys_call: + // stack: kexit_info, gas, address, value, args_offset, args_size, ret_offset, ret_size + %create_context + // stack: new_ctx, kexit_info, gas, address, value, args_offset, args_size, ret_offset, ret_size + + // Each line in the block below does not change the stack. + DUP4 %set_new_ctx_addr + %address %set_new_ctx_caller + DUP5 %set_new_ctx_value + DUP5 DUP5 %address %transfer_eth + %set_new_ctx_parent_ctx + %set_new_ctx_parent_pc(after_call_instruction) + + // TODO: Copy memory[args_offset..args_offset + args_size] CALLDATA + // TODO: Set child gas + // TODO: Populate code and codesize field. + + // stack: new_ctx, kexit_info, gas, address, value, args_offset, args_size, ret_offset, ret_size + %stack (new_ctx, kexit_info, gas, address, value, args_offset, args_size, ret_offset, ret_size) + -> (new_ctx, kexit_info, ret_offset, ret_size) + %enter_new_ctx // Creates a new sub context as if calling itself, but with the code of the // given account. In particular the storage remains the same. -global call_code: - // stack: gas, address, value, args_offset, args_size, ret_offset, ret_size, retdest - %address - %stack (self, gas, address, value) - // These are (static, should_transfer_value, value, sender, address, code_addr, gas) - -> (0, 1, value, self, self, address, gas) - %jump(call_common) +global sys_callcode: + // stack: kexit_info, gas, address, value, args_offset, args_size, ret_offset, ret_size + %create_context + // stack: new_ctx, kexit_info, gas, address, value, args_offset, args_size, ret_offset, ret_size + + // Each line in the block below does not change the stack. + %address %set_new_ctx_addr + %address %set_new_ctx_caller + DUP5 %set_new_ctx_value + DUP5 DUP5 %address %transfer_eth + %set_new_ctx_parent_ctx + %set_new_ctx_parent_pc(after_call_instruction) + + // stack: new_ctx, kexit_info, gas, address, value, args_offset, args_size, ret_offset, ret_size + %stack (new_ctx, kexit_info, gas, address, value, args_offset, args_size, ret_offset, ret_size) + -> (new_ctx, kexit_info, ret_offset, ret_size) + %enter_new_ctx // Creates a new sub context and executes the code of the given account. // Equivalent to CALL, except that it does not allow any state modifying // instructions or sending ETH in the sub context. The disallowed instructions // are CREATE, CREATE2, LOG0, LOG1, LOG2, LOG3, LOG4, SSTORE, SELFDESTRUCT and // CALL if the value sent is not 0. -global static_all: - // stack: gas, address, args_offset, args_size, ret_offset, ret_size, retdest - %address - %stack (self, gas, address) - // These are (static, should_transfer_value, value, sender, address, code_addr, gas) - -> (1, 0, 0, self, address, address, gas) - %jump(call_common) +global sys_staticcall: + // stack: kexit_info, gas, address, args_offset, args_size, ret_offset, ret_size + %create_context + // stack: new_ctx, kexit_info, gas, address, args_offset, args_size, ret_offset, ret_size + + // Each line in the block below does not change the stack. + %set_static_true + DUP4 %set_new_ctx_addr + %address %set_new_ctx_caller + PUSH 0 %set_new_ctx_value + %set_new_ctx_parent_ctx + %set_new_ctx_parent_pc(after_call_instruction) + + %stack (new_ctx, kexit_info, gas, address, args_offset, args_size, ret_offset, ret_size) + -> (new_ctx, kexit_info, ret_offset, ret_size) + %enter_new_ctx // Creates a new sub context as if calling itself, but with the code of the // given account. In particular the storage, the current sender and the current // value remain the same. -global delegate_call: - // stack: gas, address, args_offset, args_size, ret_offset, ret_size, retdest - %address - %sender - %callvalue - %stack (self, sender, value, gas, address) - // These are (static, should_transfer_value, value, sender, address, code_addr, gas) - -> (0, 0, value, sender, self, address, gas) - %jump(call_common) - -// Pre stack: static, should_transfer_value, value, sender, address, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest -// Post stack: success, leftover_gas -global call_common: - // stack: static, should_transfer_value, value, sender, address, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest +global sys_delegatecall: + // stack: kexit_info, gas, address, args_offset, args_size, ret_offset, ret_size %create_context - // Store the static flag in metadata. - %stack (new_ctx, static) -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_STATIC, static, new_ctx) + // stack: new_ctx, kexit_info, gas, address, args_offset, args_size, ret_offset, ret_size + + // Each line in the block below does not change the stack. + %address %set_new_ctx_addr + %caller %set_new_ctx_caller + %callvalue %set_new_ctx_value + %set_new_ctx_parent_ctx + %set_new_ctx_parent_pc(after_call_instruction) + + %stack (new_ctx, kexit_info, gas, address, args_offset, args_size, ret_offset, ret_size) + -> (new_ctx, kexit_info, ret_offset, ret_size) + %enter_new_ctx + +// We go here after any CALL type instruction (but not after the special call by the transaction originator). +global after_call_instruction: + // stack: success, leftover_gas, new_ctx, kexit_info, ret_offset, ret_size + SWAP3 + // stack: kexit_info, leftover_gas, new_ctx, success, ret_offset, ret_size + // Add the leftover gas into the appropriate bits of kexit_info. + SWAP1 %shl_const(192) ADD + // stack: kexit_info, new_ctx, success, ret_offset, ret_size + + // The callee's terminal instruction will have populated RETURNDATA. + // TODO: Copy RETURNDATA to memory[ret_offset..ret_offset + ret_size]. + + %stack (kexit_info, new_ctx, success, ret_offset, ret_size) + -> (kexit_info, success) + EXIT_KERNEL + +// Set @CTX_METADATA_STATIC to 1. Note that there is no corresponding set_static_false routine +// because it will already be 0 by default. +%macro set_static_true + // stack: new_ctx + %stack (new_ctx) -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_STATIC, 1, new_ctx) MSTORE_GENERAL - // stack: new_ctx, should_transfer_value, value, sender, address, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + // stack: new_ctx +%endmacro - // Store the address in metadata. - %stack (new_ctx, should_transfer_value, value, sender, address) - -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_ADDRESS, address, - new_ctx, should_transfer_value, value, sender, address) +%macro set_new_ctx_addr + // stack: called_addr, new_ctx + %stack (called_addr, new_ctx) + -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_ADDRESS, called_addr, new_ctx) MSTORE_GENERAL - // stack: new_ctx, should_transfer_value, value, sender, address, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + // stack: new_ctx +%endmacro - // Store the caller in metadata. - %stack (new_ctx, should_transfer_value, value, sender) - -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_CALLER, sender, - new_ctx, should_transfer_value, value, sender) +%macro set_new_ctx_caller + // stack: sender, new_ctx + %stack (sender, new_ctx) + -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_CALLER, sender, new_ctx) MSTORE_GENERAL - // stack: new_ctx, should_transfer_value, value, sender, address, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + // stack: new_ctx +%endmacro - // Store the call value field in metadata. - %stack (new_ctx, should_transfer_value, value, sender, address) = - -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_CALL_VALUE, value, - should_transfer_value, sender, address, value, new_ctx) +%macro set_new_ctx_value + // stack: value, new_ctx + %stack (value, new_ctx) + -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_CALL_VALUE, value, new_ctx) MSTORE_GENERAL - // stack: should_transfer_value, sender, address, value, new_ctx, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + // stack: new_ctx +%endmacro - %maybe_transfer_eth - // stack: new_ctx, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest +%macro set_new_ctx_code_size + // stack: code_size, new_ctx + %stack (code_size, new_ctx) + -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_CODE_SIZE, code_size, new_ctx) + MSTORE_GENERAL + // stack: new_ctx +%endmacro - // Store parent context in metadata. +%macro set_new_ctx_gas_limit + // stack: gas_limit, new_ctx + %stack (gas_limit, new_ctx) + -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_CODE_SIZE, gas_limit, new_ctx) + MSTORE_GENERAL + // stack: new_ctx +%endmacro + +%macro set_new_ctx_parent_ctx + // stack: new_ctx GET_CONTEXT PUSH @CTX_METADATA_PARENT_CONTEXT PUSH @SEGMENT_CONTEXT_METADATA DUP4 // new_ctx MSTORE_GENERAL - // stack: new_ctx, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + // stack: new_ctx +%endmacro - // Store parent PC = after_call. - %stack (new_ctx) -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_PARENT_PC, after_call, new_ctx) +%macro set_new_ctx_parent_pc(label) + // stack: new_ctx + %stack (new_ctx) + -> (new_ctx, @SEGMENT_CONTEXT_METADATA, @CTX_METADATA_PARENT_PC, $label, new_ctx) MSTORE_GENERAL - // stack: new_ctx, code_addr, gas, args_offset, args_size, ret_offset, ret_size, retdest + // stack: new_ctx +%endmacro - // TODO: Populate CALLDATA - // TODO: Save parent gas and set child gas - // TODO: Populate code - - // TODO: Temporary, remove after above steps are done. - %stack (new_ctx, code_addr, gas, args_offset, args_size) -> (new_ctx) - // stack: new_ctx, ret_offset, ret_size, retdest - - // Now, switch to the new context and go to usermode with PC=0. +%macro enter_new_ctx + // stack: new_ctx + // Switch to the new context and go to usermode with PC=0. DUP1 // new_ctx SET_CONTEXT PUSH 0 // jump dest EXIT_KERNEL - -after_call: - // stack: new_ctx, ret_offset, ret_size, retdest - // TODO: Set RETURNDATA. - // TODO: Return to caller w/ EXIT_KERNEL. - // TODO: Return leftover gas + // (Old context) stack: new_ctx +%endmacro diff --git a/evm/src/cpu/kernel/asm/core/create.asm b/evm/src/cpu/kernel/asm/core/create.asm index e2552ff5..eb0f821e 100644 --- a/evm/src/cpu/kernel/asm/core/create.asm +++ b/evm/src/cpu/kernel/asm/core/create.asm @@ -26,14 +26,18 @@ global create: // CREATE2; see EIP-1014. Address will be // address = KEC(0xff || sender || salt || code_hash)[12:] // -// Pre stack: sender, endowment, salt, CODE_ADDR, code_len, retdest +// Pre stack: sender, endowment, salt, CODE_ADDR: 3, code_len, retdest // Post stack: address // Note: CODE_ADDR refers to a (context, segment, offset) tuple. global sys_create2: - // stack: sender, endowment, salt, CODE_ADDR, code_len, retdest + // stack: sender, endowment, salt, CODE_ADDR: 3, code_len, retdest + DUP7 DUP7 DUP7 DUP7 // CODE_ADDR: 3, code_len + KECCAK_GENERAL + // stack: code_hash, sender, endowment, salt, CODE_ADDR: 3, code_len, retdest + // Call get_create2_address and have it return to create_inner. - %stack (sender, endowment, salt, CODE_ADDR: 3, code_len) - -> (sender, salt, CODE_ADDR, code_len, create_inner, sender, endowment, CODE_ADDR, code_len) + %stack (code_hash, sender, endowment, salt) + -> (sender, salt, code_hash, create_inner, sender, endowment) // stack: sender, salt, CODE_ADDR, code_len, create_inner, sender, endowment, CODE_ADDR, code_len, retdest %jump(get_create2_address) diff --git a/evm/src/cpu/kernel/asm/core/create_addresses.asm b/evm/src/cpu/kernel/asm/core/create_addresses.asm index ceda8b13..2d94ee94 100644 --- a/evm/src/cpu/kernel/asm/core/create_addresses.asm +++ b/evm/src/cpu/kernel/asm/core/create_addresses.asm @@ -14,14 +14,12 @@ global get_create_address: // Computes the address for a contract based on the CREATE2 rule, i.e. // address = KEC(0xff || sender || salt || code_hash)[12:] // -// Pre stack: sender, salt, CODE_ADDR, code_len, retdest +// Pre stack: sender, salt, code_hash, retdest // Post stack: address -// -// Note: CODE_ADDR is a (context, segment, offset) tuple. global get_create2_address: - // stack: sender, salt, CODE_ADDR, code_len, retdest + // stack: sender, salt, code_hash, retdest // TODO: Replace with actual implementation. - %pop6 + %pop3 PUSH 123 SWAP1 JUMP diff --git a/evm/src/cpu/kernel/asm/core/process_txn.asm b/evm/src/cpu/kernel/asm/core/process_txn.asm index 6adc0831..17770ba4 100644 --- a/evm/src/cpu/kernel/asm/core/process_txn.asm +++ b/evm/src/cpu/kernel/asm/core/process_txn.asm @@ -29,6 +29,7 @@ global validate: // TODO: Assert nonce is correct. // TODO: Assert sender has no code. // TODO: Assert sender balance >= gas_limit * gas_price + value. + // TODO: Assert chain ID matches block metadata? // stack: retdest global buy_gas: @@ -38,7 +39,7 @@ global buy_gas: // stack: gas_cost, retdest %mload_txn_field(@TXN_FIELD_ORIGIN) // stack: sender_addr, gas_cost, retdest - %deduct_eth + %deduct_eth // TODO: It should be transferred to coinbase instead? // stack: deduct_eth_status, retdest global txn_failure_insufficient_balance: %jumpi(panic) @@ -109,68 +110,32 @@ global process_message_txn_insufficient_balance: PANIC // TODO global process_message_txn_return: - // TODO: Return leftover gas? + // TODO: Since there was no code to execute, do we still return leftover gas? JUMP global process_message_txn_code_loaded: - // stack: code_len, new_ctx, retdest - POP + // stack: code_size, new_ctx, retdest + %set_new_ctx_code_size // stack: new_ctx, retdest - // Store the address in metadata. - %mload_txn_field(@TXN_FIELD_TO) - PUSH @CTX_METADATA_ADDRESS - PUSH @SEGMENT_CONTEXT_METADATA - DUP4 // new_ctx - MSTORE_GENERAL + // Each line in the block below does not change the stack. + %mload_txn_field(@TXN_FIELD_TO) %set_new_ctx_addr + %mload_txn_field(@TXN_FIELD_ORIGIN) %set_new_ctx_caller + %mload_txn_field(@TXN_FIELD_VALUE) %set_new_ctx_value + %set_new_ctx_parent_ctx + %set_new_ctx_parent_pc(process_message_txn_after_call) + %mload_txn_field(@TXN_FIELD_GAS_LIMIT) %set_new_ctx_gas_limit // stack: new_ctx, retdest - // Store the caller in metadata. - %mload_txn_field(@TXN_FIELD_ORIGIN) - PUSH @CTX_METADATA_CALLER - PUSH @SEGMENT_CONTEXT_METADATA - DUP4 // new_ctx - MSTORE_GENERAL - // stack: new_ctx, retdest + // TODO: Copy TXN_DATA to CALLDATA - // Store the call value field in metadata. - %mload_txn_field(@TXN_FIELD_VALUE) - PUSH @CTX_METADATA_CALL_VALUE - PUSH @SEGMENT_CONTEXT_METADATA - DUP4 // new_ctx - MSTORE_GENERAL - // stack: new_ctx, retdest - - // No need to write @CTX_METADATA_STATIC, because it's 0 which is the default. - - // Store parent context in metadata. - GET_CONTEXT - PUSH @CTX_METADATA_PARENT_CONTEXT - PUSH @SEGMENT_CONTEXT_METADATA - DUP4 // new_ctx - MSTORE_GENERAL - // stack: new_ctx, retdest - - // Store parent PC = process_message_txn_after_call. - PUSH process_message_txn_after_call - PUSH @CTX_METADATA_PARENT_PC - PUSH @SEGMENT_CONTEXT_METADATA - DUP4 // new_ctx - MSTORE_GENERAL - // stack: new_ctx, retdest - - // TODO: Populate CALLDATA - - // TODO: Save parent gas and set child gas - - // Now, switch to the new context and go to usermode with PC=0. - SET_CONTEXT - // stack: retdest - PUSH 0 // jump dest - EXIT_KERNEL + %enter_new_ctx global process_message_txn_after_call: - // stack: success, retdest - // TODO: Return leftover gas? Or handled by termination instructions? - POP // Pop success for now. Will go into the receipt when we support that. + // stack: success, leftover_gas, new_ctx, retdest + POP // TODO: Success will go into the receipt when we support that. + // stack: leftover_gas, new_ctx, retdest + POP // TODO: Refund leftover gas. + // stack: new_ctx, retdest + POP JUMP diff --git a/evm/src/cpu/kernel/asm/core/syscall_stubs.asm b/evm/src/cpu/kernel/asm/core/syscall_stubs.asm index d7f5b912..6dcbbb6e 100644 --- a/evm/src/cpu/kernel/asm/core/syscall_stubs.asm +++ b/evm/src/cpu/kernel/asm/core/syscall_stubs.asm @@ -13,32 +13,24 @@ global sys_sgt: PANIC global sys_sar: PANIC -global sys_address: - PANIC global sys_balance: PANIC global sys_origin: PANIC -global sys_caller: - PANIC -global sys_callvalue: - PANIC global sys_calldataload: PANIC global sys_calldatasize: PANIC global sys_calldatacopy: PANIC -global sys_codesize: - PANIC global sys_codecopy: PANIC global sys_gasprice: - PANIC -global sys_extcodesize: - PANIC -global sys_extcodecopy: - PANIC + // stack: kexit_info + %mload_txn_field(@TXN_FIELD_COMPUTED_FEE_PER_GAS) + // stack: gas_price, kexit_info + SWAP1 + EXIT_KERNEL global sys_returndatasize: PANIC global sys_returndatacopy: @@ -54,19 +46,32 @@ global sys_timestamp: global sys_number: PANIC global sys_prevrandao: + // TODO: What semantics will this have for Edge? PANIC global sys_gaslimit: + // TODO: Return the block's gas limit. PANIC global sys_chainid: - PANIC + // TODO: Return the block's chain ID instead of the txn's, even though they should match. + // stack: kexit_info + %mload_txn_field(@TXN_FIELD_CHAIN_ID) + // stack: chain_id, kexit_info + SWAP1 + EXIT_KERNEL global sys_selfbalance: PANIC global sys_basefee: PANIC -global sys_msize: - PANIC global sys_gas: - PANIC + // stack: kexit_info + DUP1 %shr_const(192) + // stack: gas_used, kexit_info + %mload_context_metadata(@CTX_METADATA_GAS_LIMIT) + // stack: gas_limit, gas_used, kexit_info + SUB + // stack: gas_remaining, kexit_info + SWAP1 + EXIT_KERNEL global sys_log0: PANIC global sys_log1: @@ -77,11 +82,3 @@ global sys_log3: PANIC global sys_log4: PANIC -global sys_call: - PANIC -global sys_callcode: - PANIC -global sys_delegatecall: - PANIC -global sys_staticcall: - PANIC diff --git a/evm/src/cpu/kernel/asm/core/terminate.asm b/evm/src/cpu/kernel/asm/core/terminate.asm index e3b88d3d..341884ea 100644 --- a/evm/src/cpu/kernel/asm/core/terminate.asm +++ b/evm/src/cpu/kernel/asm/core/terminate.asm @@ -2,29 +2,38 @@ // RETURN, SELFDESTRUCT, REVERT, and exceptions such as stack underflow. global sys_stop: + // stack: kexit_info + %leftover_gas + // stack: leftover_gas // TODO: Set parent context's CTX_METADATA_RETURNDATA_SIZE to 0. - // TODO: Refund unused gas to parent. PUSH 1 // success %jump(terminate_common) global sys_return: + // stack: kexit_info + %leftover_gas + // stack: leftover_gas // TODO: Set parent context's CTX_METADATA_RETURNDATA_SIZE. - // TODO: Copy returned memory to parent context's RETURNDATA (but not if we're returning from a constructor?) - // TODO: Copy returned memory to parent context's memory (as specified in their call instruction) - // TODO: Refund unused gas to parent. + // TODO: Copy returned memory to parent context's RETURNDATA. PUSH 1 // success %jump(terminate_common) global sys_selfdestruct: + // stack: kexit_info %consume_gas_const(@GAS_SELFDESTRUCT) + %leftover_gas + // stack: leftover_gas // TODO: Destroy account. - // TODO: Refund unused gas to parent. PUSH 1 // success %jump(terminate_common) global sys_revert: - // TODO: Refund unused gas to parent. + // stack: kexit_info + %leftover_gas + // stack: leftover_gas // TODO: Revert state changes. + // TODO: Set parent context's CTX_METADATA_RETURNDATA_SIZE. + // TODO: Copy returned memory to parent context's RETURNDATA. PUSH 0 // success %jump(terminate_common) @@ -36,24 +45,36 @@ global sys_revert: // - the new stack size would be larger than 1024, or // - state modification is attempted during a static call global fault_exception: + // stack: (empty) + PUSH 0 // leftover_gas // TODO: Revert state changes. + // TODO: Set parent context's CTX_METADATA_RETURNDATA_SIZE to 0. PUSH 0 // success %jump(terminate_common) -terminate_common: - // stack: success +global terminate_common: + // stack: success, leftover_gas + // TODO: Panic if we exceeded our gas limit? + // We want to move the success flag from our (child) context's stack to the // parent context's stack. We will write it to memory, specifically // SEGMENT_KERNEL_GENERAL[0], then load it after the context switch. PUSH 0 - // stack: 0, success + // stack: 0, success, leftover_gas + %mstore_kernel_general + // stack: leftover_gas + + // Similarly, we write leftover_gas to SEGMENT_KERNEL_GENERAL[1] so that + // we can later read it after switching to the parent context. + PUSH 1 + // stack: 1, leftover_gas %mstore_kernel_general // stack: (empty) - // Similarly, we write the parent PC to SEGMENT_KERNEL_GENERAL[1] so that + // Similarly, we write the parent PC to SEGMENT_KERNEL_GENERAL[2] so that // we can later read it after switching to the parent context. %mload_context_metadata(@CTX_METADATA_PARENT_PC) - PUSH 1 + PUSH 2 %mstore_kernel(@SEGMENT_KERNEL_GENERAL) // stack: (empty) @@ -62,9 +83,19 @@ terminate_common: SET_CONTEXT // stack: (empty) - // Load the success flag and parent PC that we stored in SEGMENT_KERNEL_GENERAL. - PUSH 0 %mload_kernel_general - PUSH 1 %mload_kernel_general + // Load the fields that we stored in SEGMENT_KERNEL_GENERAL. + PUSH 1 %mload_kernel_general // leftover_gas + PUSH 0 %mload_kernel_general // success + PUSH 2 %mload_kernel_general // parent_pc - // stack: parent_pc, success + // stack: parent_pc, success, leftover_gas JUMP + +%macro leftover_gas + // stack: kexit_info + %shr_const(192) + // stack: gas_used + %mload_context_metadata(@CTX_METADATA_GAS_LIMIT) + SUB + // stack: leftover_gas +%endmacro diff --git a/evm/src/cpu/kernel/asm/main.asm b/evm/src/cpu/kernel/asm/main.asm index c6e818b2..e87b790c 100644 --- a/evm/src/cpu/kernel/asm/main.asm +++ b/evm/src/cpu/kernel/asm/main.asm @@ -6,7 +6,7 @@ global main: PUSH hash_initial_tries %jump(load_all_mpts) -hash_initial_tries: +global hash_initial_tries: %mpt_hash_state_trie %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_DIGEST_BEFORE) %mpt_hash_txn_trie %mstore_global_metadata(@GLOBAL_METADATA_TXN_TRIE_DIGEST_BEFORE) %mpt_hash_receipt_trie %mstore_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_DIGEST_BEFORE) diff --git a/evm/src/cpu/kernel/asm/memory/metadata.asm b/evm/src/cpu/kernel/asm/memory/metadata.asm index 1a495682..7ea6d9e5 100644 --- a/evm/src/cpu/kernel/asm/memory/metadata.asm +++ b/evm/src/cpu/kernel/asm/memory/metadata.asm @@ -38,18 +38,57 @@ %mload_context_metadata(@CTX_METADATA_ADDRESS) %endmacro -%macro sender +global sys_address: + // stack: kexit_info + %address + // stack: address, kexit_info + SWAP1 + EXIT_KERNEL + +%macro caller %mload_context_metadata(@CTX_METADATA_CALLER) %endmacro +global sys_caller: + // stack: kexit_info + %caller + // stack: caller, kexit_info + SWAP1 + EXIT_KERNEL + %macro callvalue %mload_context_metadata(@CTX_METADATA_CALL_VALUE) %endmacro +%macro codesize + %mload_context_metadata(@CTX_METADATA_CODE_SIZE) +%endmacro + +global sys_codesize: + // stack: kexit_info + %codesize + // stack: codesize, kexit_info + SWAP1 + EXIT_KERNEL + +global sys_callvalue: + // stack: kexit_info + %callvalue + // stack: callvalue, kexit_info + SWAP1 + EXIT_KERNEL + %macro msize %mload_context_metadata(@CTX_METADATA_MSIZE) %endmacro +global sys_msize: + // stack: kexit_info + %msize + // stack: msize, kexit_info + SWAP1 + EXIT_KERNEL + %macro update_msize // stack: offset %add_const(32) @@ -64,4 +103,3 @@ // stack: new_msize %mstore_context_metadata(@CTX_METADATA_MSIZE) %endmacro - diff --git a/evm/src/cpu/kernel/asm/mpt/storage/storage_write.asm b/evm/src/cpu/kernel/asm/mpt/storage/storage_write.asm index a56117a7..f56a5fdf 100644 --- a/evm/src/cpu/kernel/asm/mpt/storage/storage_write.asm +++ b/evm/src/cpu/kernel/asm/mpt/storage/storage_write.asm @@ -4,6 +4,12 @@ // Post stack: (empty) global sys_sstore: + // TODO: Assuming a cold zero -> nonzero write for now. + PUSH @GAS_COLDSLOAD + PUSH @GAS_SSET + ADD + %charge_gas + %stack (kexit_info, slot, value) -> (slot, value, kexit_info) // TODO: If value = 0, delete the key instead of inserting 0. // stack: slot, value, kexit_info diff --git a/evm/src/cpu/kernel/asm/transactions/common_decoding.asm b/evm/src/cpu/kernel/asm/transactions/common_decoding.asm new file mode 100644 index 00000000..71440d1c --- /dev/null +++ b/evm/src/cpu/kernel/asm/transactions/common_decoding.asm @@ -0,0 +1,139 @@ +// Store chain ID = 1. Used for non-legacy txns which always have a chain ID. +%macro store_chain_id_present_true + PUSH 1 + %mstore_txn_field(@TXN_FIELD_CHAIN_ID_PRESENT) +%endmacro + +// Decode the chain ID and store it. +%macro decode_and_store_chain_id + // stack: pos + %decode_rlp_scalar + %stack (pos, chain_id) -> (chain_id, pos) + %mstore_txn_field(@TXN_FIELD_CHAIN_ID) + // stack: pos +%endmacro + +// Decode the nonce and store it. +%macro decode_and_store_nonce + // stack: pos + %decode_rlp_scalar + %stack (pos, nonce) -> (nonce, pos) + %mstore_txn_field(@TXN_FIELD_NONCE) + // stack: pos +%endmacro + +// Decode the gas price and, since this is for legacy txns, store it as both +// TXN_FIELD_MAX_PRIORITY_FEE_PER_GAS and TXN_FIELD_MAX_FEE_PER_GAS. +%macro decode_and_store_gas_price_legacy + // stack: pos + %decode_rlp_scalar + %stack (pos, gas_price) -> (gas_price, gas_price, pos) + %mstore_txn_field(@TXN_FIELD_MAX_PRIORITY_FEE_PER_GAS) + %mstore_txn_field(@TXN_FIELD_MAX_FEE_PER_GAS) + // stack: pos +%endmacro + +// Decode the max priority fee and store it. +%macro decode_and_store_max_priority_fee + // stack: pos + %decode_rlp_scalar + %stack (pos, gas_price) -> (gas_price, pos) + %mstore_txn_field(@TXN_FIELD_MAX_PRIORITY_FEE_PER_GAS) + // stack: pos +%endmacro + +// Decode the max fee and store it. +%macro decode_and_store_max_fee + // stack: pos + %decode_rlp_scalar + %stack (pos, gas_price) -> (gas_price, pos) + %mstore_txn_field(@TXN_FIELD_MAX_FEE_PER_GAS) + // stack: pos +%endmacro + +// Decode the gas limit and store it. +%macro decode_and_store_gas_limit + // stack: pos + %decode_rlp_scalar + %stack (pos, gas_limit) -> (gas_limit, pos) + %mstore_txn_field(@TXN_FIELD_GAS_LIMIT) + // stack: pos +%endmacro + +// Decode the "to" field and store it. +%macro decode_and_store_to + // stack: pos + %decode_rlp_scalar + %stack (pos, to) -> (to, pos) + %mstore_txn_field(@TXN_FIELD_TO) + // stack: pos +%endmacro + +// Decode the "value" field and store it. +%macro decode_and_store_value + // stack: pos + %decode_rlp_scalar + %stack (pos, value) -> (value, pos) + %mstore_txn_field(@TXN_FIELD_VALUE) + // stack: pos +%endmacro + +// Decode the calldata field, store its length in @TXN_FIELD_DATA_LEN, and copy it to @SEGMENT_TXN_DATA. +%macro decode_and_store_data + // stack: pos + // Decode the data length, store it, and compute new_pos after any data. + %decode_rlp_string_len + %stack (pos, data_len) -> (data_len, pos, data_len, pos, data_len) + %mstore_txn_field(@TXN_FIELD_DATA_LEN) + // stack: pos, data_len, pos, data_len + ADD + // stack: new_pos, old_pos, data_len + + // Memcpy the txn data from @SEGMENT_RLP_RAW to @SEGMENT_TXN_DATA. + %stack (new_pos, old_pos, data_len) -> (old_pos, data_len, %%after, new_pos) + PUSH @SEGMENT_RLP_RAW + GET_CONTEXT + PUSH 0 + PUSH @SEGMENT_TXN_DATA + GET_CONTEXT + // stack: DST, SRC, data_len, %%after, new_pos + %jump(memcpy) + +%%after: + // stack: new_pos +%endmacro + +%macro decode_and_store_access_list + // stack: pos + %decode_rlp_list_len + %stack (pos, len) -> (len, pos) + %jumpi(todo_access_lists_not_supported_yet) + // stack: pos +%endmacro + +%macro decode_and_store_y_parity + // stack: pos + %decode_rlp_scalar + %stack (pos, y_parity) -> (y_parity, pos) + %mstore_txn_field(@TXN_FIELD_Y_PARITY) + // stack: pos +%endmacro + +%macro decode_and_store_r + // stack: pos + %decode_rlp_scalar + %stack (pos, r) -> (r, pos) + %mstore_txn_field(@TXN_FIELD_R) + // stack: pos +%endmacro + +%macro decode_and_store_s + // stack: pos + %decode_rlp_scalar + %stack (pos, s) -> (s, pos) + %mstore_txn_field(@TXN_FIELD_S) + // stack: pos +%endmacro + +global todo_access_lists_not_supported_yet: + PANIC diff --git a/evm/src/cpu/kernel/asm/transactions/type_0.asm b/evm/src/cpu/kernel/asm/transactions/type_0.asm index d1f00ed9..e9aedca0 100644 --- a/evm/src/cpu/kernel/asm/transactions/type_0.asm +++ b/evm/src/cpu/kernel/asm/transactions/type_0.asm @@ -19,61 +19,16 @@ global process_type_0_txn: // We don't actually need the length. %stack (pos, len) -> (pos) - // Decode the nonce and store it. // stack: pos, retdest - %decode_rlp_scalar - %stack (pos, nonce) -> (nonce, pos) - %mstore_txn_field(@TXN_FIELD_NONCE) - - // Decode the gas price and store it. - // For legacy transactions, we set both the - // TXN_FIELD_MAX_PRIORITY_FEE_PER_GAS and TXN_FIELD_MAX_FEE_PER_GAS - // fields to gas_price. + %decode_and_store_nonce + %decode_and_store_gas_price_legacy + %decode_and_store_gas_limit + %decode_and_store_to + %decode_and_store_value + %decode_and_store_data // stack: pos, retdest - %decode_rlp_scalar - %stack (pos, gas_price) -> (gas_price, gas_price, pos) - %mstore_txn_field(@TXN_FIELD_MAX_PRIORITY_FEE_PER_GAS) - %mstore_txn_field(@TXN_FIELD_MAX_FEE_PER_GAS) - // Decode the gas limit and store it. - // stack: pos, retdest - %decode_rlp_scalar - %stack (pos, gas_limit) -> (gas_limit, pos) - %mstore_txn_field(@TXN_FIELD_GAS_LIMIT) - - // Decode the "to" field and store it. - // stack: pos, retdest - %decode_rlp_scalar - %stack (pos, to) -> (to, pos) - %mstore_txn_field(@TXN_FIELD_TO) - - // Decode the value field and store it. - // stack: pos, retdest - %decode_rlp_scalar - %stack (pos, value) -> (value, pos) - %mstore_txn_field(@TXN_FIELD_VALUE) - - // Decode the data length, store it, and compute new_pos after any data. - // stack: pos, retdest - %decode_rlp_string_len - %stack (pos, data_len) -> (data_len, pos, data_len, pos, data_len) - %mstore_txn_field(@TXN_FIELD_DATA_LEN) - // stack: pos, data_len, pos, data_len, retdest - ADD - // stack: new_pos, pos, data_len, retdest - - // Memcpy the txn data from @SEGMENT_RLP_RAW to @SEGMENT_TXN_DATA. - PUSH parse_v - %stack (parse_v, new_pos, old_pos, data_len) -> (old_pos, data_len, parse_v, new_pos) - PUSH @SEGMENT_RLP_RAW - GET_CONTEXT - PUSH 0 - PUSH @SEGMENT_TXN_DATA - GET_CONTEXT - // stack: DST, SRC, data_len, parse_v, new_pos, retdest - %jump(memcpy) - -parse_v: + // Parse the "v" field. // stack: pos, retdest %decode_rlp_scalar // stack: pos, v, retdest @@ -93,7 +48,7 @@ parse_v: %mstore_txn_field(@TXN_FIELD_Y_PARITY) // stack: pos, retdest - %jump(parse_r) + %jump(decode_r_and_s) process_v_new_style: // stack: v, pos, retdest @@ -115,16 +70,12 @@ process_v_new_style: // stack: y_parity, pos, retdest %mstore_txn_field(@TXN_FIELD_Y_PARITY) -parse_r: +decode_r_and_s: // stack: pos, retdest - %decode_rlp_scalar - %stack (pos, r) -> (r, pos) - %mstore_txn_field(@TXN_FIELD_R) - + %decode_and_store_r + %decode_and_store_s // stack: pos, retdest - %decode_rlp_scalar - %stack (pos, s) -> (s) - %mstore_txn_field(@TXN_FIELD_S) + POP // stack: retdest type_0_compute_signed_data: diff --git a/evm/src/cpu/kernel/asm/transactions/type_1.asm b/evm/src/cpu/kernel/asm/transactions/type_1.asm index 8c7fcaae..fbd934d9 100644 --- a/evm/src/cpu/kernel/asm/transactions/type_1.asm +++ b/evm/src/cpu/kernel/asm/transactions/type_1.asm @@ -8,4 +8,29 @@ global process_type_1_txn: // stack: retdest - PANIC // TODO: Unfinished + PUSH 1 // initial pos, skipping over the 0x01 byte + // stack: pos, retdest + %decode_rlp_list_len + // We don't actually need the length. + %stack (pos, len) -> (pos) + + %store_chain_id_present_true + %decode_and_store_chain_id + %decode_and_store_nonce + %decode_and_store_gas_price_legacy + %decode_and_store_gas_limit + %decode_and_store_to + %decode_and_store_value + %decode_and_store_data + %decode_and_store_access_list + %decode_and_store_y_parity + %decode_and_store_r + %decode_and_store_s + + // stack: pos, retdest + POP + // stack: retdest + + // TODO: Check signature. + + %jump(process_normalized_txn) diff --git a/evm/src/cpu/kernel/asm/transactions/type_2.asm b/evm/src/cpu/kernel/asm/transactions/type_2.asm index f1ff18d8..d9586c21 100644 --- a/evm/src/cpu/kernel/asm/transactions/type_2.asm +++ b/evm/src/cpu/kernel/asm/transactions/type_2.asm @@ -9,4 +9,31 @@ global process_type_2_txn: // stack: retdest - PANIC // TODO: Unfinished + PUSH 1 // initial pos, skipping over the 0x02 byte + // stack: pos, retdest + %decode_rlp_list_len + // We don't actually need the length. + %stack (pos, len) -> (pos) + + // stack: pos, retdest + %store_chain_id_present_true + %decode_and_store_chain_id + %decode_and_store_nonce + %decode_and_store_max_priority_fee + %decode_and_store_max_fee + %decode_and_store_gas_limit + %decode_and_store_to + %decode_and_store_value + %decode_and_store_data + %decode_and_store_access_list + %decode_and_store_y_parity + %decode_and_store_r + %decode_and_store_s + + // stack: pos, retdest + POP + // stack: retdest + + // TODO: Check signature. + + %jump(process_normalized_txn) diff --git a/evm/src/cpu/kernel/assembler.rs b/evm/src/cpu/kernel/assembler.rs index 15fb1d4b..2afd328f 100644 --- a/evm/src/cpu/kernel/assembler.rs +++ b/evm/src/cpu/kernel/assembler.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::time::Instant; use ethereum_types::U256; use itertools::{izip, Itertools}; @@ -111,6 +112,7 @@ pub(crate) fn assemble( let mut local_labels = Vec::with_capacity(files.len()); let mut macro_counter = 0; for file in files { + let start = Instant::now(); let mut file = file.body; file = expand_macros(file, ¯os, &mut macro_counter); file = inline_constants(file, &constants); @@ -125,6 +127,7 @@ pub(crate) fn assemble( &mut prover_inputs, )); expanded_files.push(file); + debug!("Expanding file took {:?}", start.elapsed()); } let mut code = vec![]; for (file, locals) in izip!(expanded_files, local_labels) { @@ -134,6 +137,7 @@ pub(crate) fn assemble( debug!("Assembled file size: {} bytes", file_len); } assert_eq!(code.len(), offset, "Code length doesn't match offset."); + debug!("Total kernel size: {} bytes", code.len()); Kernel::new(code, global_labels, prover_inputs) } diff --git a/evm/src/cpu/kernel/constants/context_metadata.rs b/evm/src/cpu/kernel/constants/context_metadata.rs index 27bec078..4e869661 100644 --- a/evm/src/cpu/kernel/constants/context_metadata.rs +++ b/evm/src/cpu/kernel/constants/context_metadata.rs @@ -26,10 +26,12 @@ pub(crate) enum ContextMetadata { /// Size of the active main memory. MSize = 10, StackSize = 11, + /// The gas limit for this call (not the entire transaction). + GasLimit = 12, } impl ContextMetadata { - pub(crate) const COUNT: usize = 12; + pub(crate) const COUNT: usize = 13; pub(crate) fn all() -> [Self; Self::COUNT] { [ @@ -45,6 +47,7 @@ impl ContextMetadata { Self::StateTrieCheckpointPointer, Self::MSize, Self::StackSize, + Self::GasLimit, ] } @@ -63,6 +66,7 @@ impl ContextMetadata { ContextMetadata::StateTrieCheckpointPointer => "CTX_METADATA_STATE_TRIE_CHECKPOINT_PTR", ContextMetadata::MSize => "CTX_METADATA_MSIZE", ContextMetadata::StackSize => "CTX_METADATA_STACK_SIZE", + ContextMetadata::GasLimit => "CTX_METADATA_GAS_LIMIT", } } } diff --git a/evm/src/cpu/kernel/constants/trie_type.rs b/evm/src/cpu/kernel/constants/trie_type.rs index 08fd8748..30f4802b 100644 --- a/evm/src/cpu/kernel/constants/trie_type.rs +++ b/evm/src/cpu/kernel/constants/trie_type.rs @@ -1,5 +1,6 @@ use eth_trie_utils::partial_trie::PartialTrie; +#[derive(Copy, Clone)] pub(crate) enum PartialTrieType { Empty = 0, Hash = 1, diff --git a/evm/src/cpu/kernel/tests/core/create_addresses.rs b/evm/src/cpu/kernel/tests/core/create_addresses.rs index c77ff937..047ddc00 100644 --- a/evm/src/cpu/kernel/tests/core/create_addresses.rs +++ b/evm/src/cpu/kernel/tests/core/create_addresses.rs @@ -28,23 +28,12 @@ fn test_get_create2_address() -> Result<()> { // TODO: Replace with real data once we have a real implementation. let retaddr = 0xdeadbeefu32.into(); - let code_len = 0.into(); - let code_offset = 0.into(); - let code_segment = 0.into(); - let code_context = 0.into(); + let code_hash = 0.into(); let salt = 5.into(); let sender = 0.into(); let expected_addr = 123.into(); - let initial_stack = vec![ - retaddr, - code_len, - code_offset, - code_segment, - code_context, - salt, - sender, - ]; + let initial_stack = vec![retaddr, code_hash, salt, sender]; let mut interpreter = Interpreter::new_with_kernel(get_create2_address, initial_stack); interpreter.run()?; diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 858bb111..9bde0106 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -18,7 +18,10 @@ use crate::config::StarkConfig; use crate::cpu::bootstrap_kernel::generate_bootstrap_kernel; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; +use crate::cpu::kernel::constants::global_metadata::GlobalMetadata::StateTrieRoot; +use crate::generation::mpt::AccountRlp; use crate::generation::state::GenerationState; +use crate::generation::trie_extractor::read_state_trie_value; use crate::memory::segments::Segment; use crate::proof::{BlockMetadata, PublicValues, TrieRoots}; use crate::witness::memory::MemoryAddress; @@ -28,6 +31,8 @@ pub mod mpt; pub(crate) mod prover_input; pub(crate) mod rlp; pub(crate) mod state; +mod trie_extractor; +use crate::generation::trie_extractor::read_trie; #[derive(Clone, Debug, Deserialize, Serialize, Default)] /// Inputs needed for trace generation. @@ -67,12 +72,12 @@ pub(crate) fn generate_traces, const D: usize>( inputs: GenerationInputs, config: &StarkConfig, timing: &mut TimingTree, -) -> ([Vec>; NUM_TABLES], PublicValues) { +) -> anyhow::Result<([Vec>; NUM_TABLES], PublicValues)> { let mut state = GenerationState::::new(inputs.clone(), &KERNEL.code); generate_bootstrap_kernel::(&mut state); - timed!(timing, "simulate CPU", simulate_cpu(&mut state)); + timed!(timing, "simulate CPU", simulate_cpu(&mut state)?); log::info!( "Trace lengths (before padding): {:?}", @@ -87,6 +92,15 @@ pub(crate) fn generate_traces, const D: usize>( )) }; + log::debug!( + "Updated state trie:\n{:#?}", + read_trie::( + &state.memory, + read_metadata(StateTrieRoot).as_usize(), + read_state_trie_value + ) + ); + let trie_roots_before = TrieRoots { state_root: H256::from_uint(&read_metadata(StateTrieRootDigestBefore)), transactions_root: H256::from_uint(&read_metadata(TransactionTrieRootDigestBefore)), @@ -109,10 +123,12 @@ pub(crate) fn generate_traces, const D: usize>( "convert trace data to tables", state.traces.into_tables(all_stark, config, timing) ); - (tables, public_values) + Ok((tables, public_values)) } -fn simulate_cpu, const D: usize>(state: &mut GenerationState) { +fn simulate_cpu, const D: usize>( + state: &mut GenerationState, +) -> anyhow::Result<()> { let halt_pc0 = KERNEL.global_labels["halt_pc0"]; let halt_pc1 = KERNEL.global_labels["halt_pc1"]; @@ -126,11 +142,13 @@ fn simulate_cpu, const D: usize>(state: &mut Genera } already_in_halt_loop |= in_halt_loop; - transition(state); + transition(state)?; if already_in_halt_loop && state.traces.clock().is_power_of_two() { log::info!("CPU trace padded to {} cycles", state.traces.clock()); break; } } + + Ok(()) } diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index bf1fbd74..88f17ade 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -33,6 +33,15 @@ pub(crate) struct GenerationState { impl GenerationState { pub(crate) fn new(inputs: GenerationInputs, kernel_code: &[u8]) -> Self { + log::debug!("Input signed_txns: {:?}", &inputs.signed_txns); + log::debug!("Input state_trie: {:?}", &inputs.tries.state_trie); + log::debug!( + "Input transactions_trie: {:?}", + &inputs.tries.transactions_trie + ); + log::debug!("Input receipts_trie: {:?}", &inputs.tries.receipts_trie); + log::debug!("Input storage_tries: {:?}", &inputs.tries.storage_tries); + log::debug!("Input contract_code: {:?}", &inputs.contract_code); let mpt_prover_inputs = all_mpt_prover_inputs_reversed(&inputs.tries); let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns); diff --git a/evm/src/generation/trie_extractor.rs b/evm/src/generation/trie_extractor.rs new file mode 100644 index 00000000..d35d67eb --- /dev/null +++ b/evm/src/generation/trie_extractor.rs @@ -0,0 +1,98 @@ +use std::collections::HashMap; + +use eth_trie_utils::partial_trie::Nibbles; +use ethereum_types::{BigEndianHash, H256, U256}; +use plonky2::field::extension::Extendable; +use plonky2::hash::hash_types::RichField; + +use crate::cpu::kernel::constants::trie_type::PartialTrieType; +use crate::generation::mpt::AccountRlp; +use crate::memory::segments::Segment; +use crate::witness::memory::{MemoryAddress, MemoryState}; + +pub(crate) fn read_state_trie_value(slice: &[U256]) -> AccountRlp { + AccountRlp { + nonce: slice[0], + balance: slice[1], + storage_root: H256::from_uint(&slice[2]), + code_hash: H256::from_uint(&slice[3]), + } +} + +pub(crate) fn read_trie( + memory: &MemoryState, + ptr: usize, + read_value: fn(&[U256]) -> V, +) -> HashMap +where + F: RichField + Extendable, +{ + let mut res = HashMap::new(); + let empty_nibbles = Nibbles { + count: 0, + packed: U256::zero(), + }; + read_trie_helper::(memory, ptr, read_value, empty_nibbles, &mut res); + res +} + +pub(crate) fn read_trie_helper( + memory: &MemoryState, + ptr: usize, + read_value: fn(&[U256]) -> V, + prefix: Nibbles, + res: &mut HashMap, +) where + F: RichField + Extendable, +{ + let load = |offset| memory.get(MemoryAddress::new(0, Segment::TrieData, offset)); + let load_slice_from = |init_offset| { + &memory.contexts[0].segments[Segment::TrieData as usize].content[init_offset..] + }; + + let trie_type = PartialTrieType::all()[load(ptr).as_usize()]; + match trie_type { + PartialTrieType::Empty => {} + PartialTrieType::Hash => {} + PartialTrieType::Branch => { + let ptr_payload = ptr + 1; + for i in 0u8..16 { + let child_ptr = load(ptr_payload + i as usize).as_usize(); + read_trie_helper::( + memory, + child_ptr, + read_value, + prefix.merge_nibble(i), + res, + ); + } + let value_ptr = load(ptr_payload + 16).as_usize(); + if value_ptr != 0 { + res.insert(prefix, read_value(load_slice_from(value_ptr))); + }; + } + PartialTrieType::Extension => { + let count = load(ptr + 1).as_usize(); + let packed = load(ptr + 2); + let nibbles = Nibbles { count, packed }; + let child_ptr = load(ptr + 3).as_usize(); + read_trie_helper::( + memory, + child_ptr, + read_value, + prefix.merge_nibbles(&nibbles), + res, + ); + } + PartialTrieType::Leaf => { + let count = load(ptr + 1).as_usize(); + let packed = load(ptr + 2); + let nibbles = Nibbles { count, packed }; + let value_ptr = load(ptr + 3).as_usize(); + res.insert( + prefix.merge_nibbles(&nibbles), + read_value(load_slice_from(value_ptr)), + ); + } + } +} diff --git a/evm/src/prover.rs b/evm/src/prover.rs index c801950a..9e26218a 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -60,7 +60,7 @@ where let (traces, public_values) = timed!( timing, "generate all traces", - generate_traces(all_stark, inputs, config, timing) + generate_traces(all_stark, inputs, config, timing)? ); prove_with_traces(all_stark, config, traces, public_values, timing) } diff --git a/evm/src/witness/errors.rs b/evm/src/witness/errors.rs index bd4b03c9..53263675 100644 --- a/evm/src/witness/errors.rs +++ b/evm/src/witness/errors.rs @@ -7,4 +7,5 @@ pub enum ProgramError { InvalidJumpDestination, InvalidJumpiDestination, StackOverflow, + KernelPanic, } diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index d233655c..ff10b08b 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -1,3 +1,4 @@ +use anyhow::bail; use itertools::Itertools; use log::log_enabled; use plonky2::field::types::Field; @@ -117,10 +118,13 @@ fn decode(registers: RegistersState, opcode: u8) -> Result Ok(Operation::Syscall(opcode)), (0xa3, _) => Ok(Operation::Syscall(opcode)), (0xa4, _) => Ok(Operation::Syscall(opcode)), - (0xa5, _) => panic!( - "Kernel panic at {}", - KERNEL.offset_name(registers.program_counter) - ), + (0xa5, _) => { + log::warn!( + "Kernel panic at {}", + KERNEL.offset_name(registers.program_counter) + ); + Err(ProgramError::KernelPanic) + } (0xf0, _) => Ok(Operation::Syscall(opcode)), (0xf1, _) => Ok(Operation::Syscall(opcode)), (0xf2, _) => Ok(Operation::Syscall(opcode)), @@ -288,11 +292,11 @@ fn log_kernel_instruction(state: &mut GenerationState, op: Operatio assert!(pc < KERNEL.code.len(), "Kernel PC is out of range: {}", pc); } -fn handle_error(_state: &mut GenerationState) { - todo!("generation for exception handling is not implemented"); +fn handle_error(_state: &mut GenerationState) -> anyhow::Result<()> { + bail!("TODO: generation for exception handling is not implemented"); } -pub(crate) fn transition(state: &mut GenerationState) { +pub(crate) fn transition(state: &mut GenerationState) -> anyhow::Result<()> { let checkpoint = state.checkpoint(); let result = try_perform_instruction(state); @@ -301,11 +305,12 @@ pub(crate) fn transition(state: &mut GenerationState) { state .memory .apply_ops(state.traces.mem_ops_since(checkpoint.traces)); + Ok(()) } Err(e) => { if state.registers.is_kernel { let offset_name = KERNEL.offset_name(state.registers.program_counter); - panic!("exception in kernel mode at {}: {:?}", offset_name, e); + bail!("exception in kernel mode at {}: {:?}", offset_name, e); } state.rollback(checkpoint); handle_error(state) diff --git a/insertion/Cargo.toml b/insertion/Cargo.toml deleted file mode 100644 index 125e72e3..00000000 --- a/insertion/Cargo.toml +++ /dev/null @@ -1,13 +0,0 @@ -[package] -name = "plonky2_insertion" -description = "Circuit implementation of list insertion" -version = "0.1.0" -edition = "2021" - -[dependencies] -anyhow = { version = "1.0.40", default-features = false } -plonky2 = { version = "0.1.2", default-features = false } - -[dev-dependencies] -plonky2 = { version = "0.1.2" } - diff --git a/insertion/LICENSE-APACHE b/insertion/LICENSE-APACHE deleted file mode 100644 index 1e5006dc..00000000 --- a/insertion/LICENSE-APACHE +++ /dev/null @@ -1,202 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - diff --git a/insertion/LICENSE-MIT b/insertion/LICENSE-MIT deleted file mode 100644 index 86d690b2..00000000 --- a/insertion/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2022 The Plonky2 Authors - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. diff --git a/insertion/README.md b/insertion/README.md deleted file mode 100644 index bb4e2d8a..00000000 --- a/insertion/README.md +++ /dev/null @@ -1,13 +0,0 @@ -## License - -Licensed under either of - -* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) -* MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) - -at your option. - - -### Contribution - -Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions. diff --git a/insertion/src/insert_gadget.rs b/insertion/src/insert_gadget.rs deleted file mode 100644 index 1574d8fb..00000000 --- a/insertion/src/insert_gadget.rs +++ /dev/null @@ -1,112 +0,0 @@ -use alloc::vec; -use alloc::vec::Vec; - -use plonky2::field::extension::Extendable; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::ext_target::ExtensionTarget; -use plonky2::iop::target::Target; -use plonky2::plonk::circuit_builder::CircuitBuilder; - -use crate::insertion_gate::InsertionGate; - -pub trait CircuitBuilderInsert, const D: usize> { - /// Inserts a `Target` in a vector at a non-deterministic index. - /// Note: `index` is not range-checked. - fn insert( - &mut self, - index: Target, - element: ExtensionTarget, - v: Vec>, - ) -> Vec>; -} - -impl, const D: usize> CircuitBuilderInsert - for CircuitBuilder -{ - fn insert( - &mut self, - index: Target, - element: ExtensionTarget, - v: Vec>, - ) -> Vec> { - let gate = InsertionGate::new(v.len()); - let row = self.add_gate(gate.clone(), vec![]); - - v.iter().enumerate().for_each(|(i, &val)| { - self.connect_extension( - val, - ExtensionTarget::from_range(row, gate.wires_original_list_item(i)), - ); - }); - self.connect(index, Target::wire(row, gate.wires_insertion_index())); - self.connect_extension( - element, - ExtensionTarget::from_range(row, gate.wires_element_to_insert()), - ); - - (0..=v.len()) - .map(|i| ExtensionTarget::from_range(row, gate.wires_output_list_item(i))) - .collect::>() - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::field::types::{Field, Sample}; - use plonky2::iop::witness::PartialWitness; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - - use super::*; - - fn real_insert( - index: usize, - element: ExtensionTarget, - v: &[ExtensionTarget], - ) -> Vec> { - let mut res = v.to_vec(); - res.insert(index, element); - res - } - - fn test_insert_given_len(len_log: usize) -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type FF = >::FE; - let len = 1 << len_log; - let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - let v = (0..len - 1) - .map(|_| builder.constant_extension(FF::rand())) - .collect::>(); - - for i in 0..len { - let it = builder.constant(F::from_canonical_usize(i)); - let elem = builder.constant_extension(FF::rand()); - let inserted = real_insert(i, elem, &v); - let purported_inserted = builder.insert(it, elem, v.clone()); - - assert_eq!(inserted.len(), purported_inserted.len()); - - for (x, y) in inserted.into_iter().zip(purported_inserted) { - builder.connect_extension(x, y); - } - } - - let data = builder.build::(); - let proof = data.prove(pw)?; - - data.verify(proof) - } - - #[test] - fn test_insert() -> Result<()> { - for len_log in 1..3 { - test_insert_given_len(len_log)?; - } - Ok(()) - } -} diff --git a/insertion/src/insertion_gate.rs b/insertion/src/insertion_gate.rs deleted file mode 100644 index f9dc5fce..00000000 --- a/insertion/src/insertion_gate.rs +++ /dev/null @@ -1,424 +0,0 @@ -use alloc::boxed::Box; -use alloc::string::String; -use alloc::vec::Vec; -use alloc::{format, vec}; -use core::marker::PhantomData; -use core::ops::Range; - -use plonky2::field::extension::{Extendable, FieldExtension}; -use plonky2::field::types::Field; -use plonky2::gates::gate::Gate; -use plonky2::gates::util::StridedConstraintConsumer; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::ext_target::ExtensionTarget; -use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; -use plonky2::iop::target::Target; -use plonky2::iop::wire::Wire; -use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; - -/// A gate for inserting a value into a list at a non-deterministic location. -#[derive(Clone, Debug)] -pub(crate) struct InsertionGate, const D: usize> { - pub vec_size: usize, - _phantom: PhantomData, -} - -impl, const D: usize> InsertionGate { - pub fn new(vec_size: usize) -> Self { - Self { - vec_size, - _phantom: PhantomData, - } - } - - pub fn wires_insertion_index(&self) -> usize { - 0 - } - - pub fn wires_element_to_insert(&self) -> Range { - 1..D + 1 - } - - pub fn wires_original_list_item(&self, i: usize) -> Range { - debug_assert!(i < self.vec_size); - let start = (i + 1) * D + 1; - start..start + D - } - - fn start_of_output_wires(&self) -> usize { - (self.vec_size + 1) * D + 1 - } - - pub fn wires_output_list_item(&self, i: usize) -> Range { - debug_assert!(i <= self.vec_size); - let start = self.start_of_output_wires() + i * D; - start..start + D - } - - fn start_of_intermediate_wires(&self) -> usize { - self.start_of_output_wires() + (self.vec_size + 1) * D - } - - /// An intermediate wire for a dummy variable used to show equality. - /// The prover sets this to 1/(x-y) if x != y, or to an arbitrary value if - /// x == y. - pub fn wire_equality_dummy_for_round_r(&self, r: usize) -> usize { - self.start_of_intermediate_wires() + r - } - - // An intermediate wire for the "insert_here" variable (1 if the current index is the index at - /// which to insert the new value, 0 otherwise). - pub fn wire_insert_here_for_round_r(&self, r: usize) -> usize { - self.start_of_intermediate_wires() + (self.vec_size + 1) + r - } -} - -impl, const D: usize> Gate for InsertionGate { - fn id(&self) -> String { - format!("{self:?}") - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let insertion_index = vars.local_wires[self.wires_insertion_index()]; - let list_items = (0..self.vec_size) - .map(|i| vars.get_local_ext_algebra(self.wires_original_list_item(i))) - .collect::>(); - let output_list_items = (0..=self.vec_size) - .map(|i| vars.get_local_ext_algebra(self.wires_output_list_item(i))) - .collect::>(); - let element_to_insert = vars.get_local_ext_algebra(self.wires_element_to_insert()); - - let mut constraints = Vec::with_capacity(self.num_constraints()); - let mut already_inserted = F::Extension::ZERO; - for r in 0..=self.vec_size { - let cur_index = F::Extension::from_canonical_usize(r); - let difference = cur_index - insertion_index; - let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_round_r(r)]; - let insert_here = vars.local_wires[self.wire_insert_here_for_round_r(r)]; - - // The two equality constraints. - constraints.push(difference * equality_dummy - (F::Extension::ONE - insert_here)); - constraints.push(insert_here * difference); - - let mut new_item = element_to_insert.scalar_mul(insert_here); - if r > 0 { - new_item += list_items[r - 1].scalar_mul(already_inserted); - } - already_inserted += insert_here; - if r < self.vec_size { - new_item += list_items[r].scalar_mul(F::Extension::ONE - already_inserted); - } - - // Output constraint. - constraints.extend((new_item - output_list_items[r]).to_basefield_array()); - } - - constraints - } - - fn eval_unfiltered_base_one( - &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, - ) { - let insertion_index = vars.local_wires[self.wires_insertion_index()]; - let list_items = (0..self.vec_size) - .map(|i| vars.get_local_ext(self.wires_original_list_item(i))) - .collect::>(); - let output_list_items = (0..=self.vec_size) - .map(|i| vars.get_local_ext(self.wires_output_list_item(i))) - .collect::>(); - let element_to_insert = vars.get_local_ext(self.wires_element_to_insert()); - - let mut already_inserted = F::ZERO; - for r in 0..=self.vec_size { - let cur_index = F::from_canonical_usize(r); - let difference = cur_index - insertion_index; - let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_round_r(r)]; - let insert_here = vars.local_wires[self.wire_insert_here_for_round_r(r)]; - - // The two equality constraints. - yield_constr.one(difference * equality_dummy - (F::ONE - insert_here)); - yield_constr.one(insert_here * difference); - - let mut new_item = element_to_insert.scalar_mul(insert_here); - if r > 0 { - new_item += list_items[r - 1].scalar_mul(already_inserted); - } - already_inserted += insert_here; - if r < self.vec_size { - new_item += list_items[r].scalar_mul(F::ONE - already_inserted); - } - - // Output constraint. - yield_constr.many((new_item - output_list_items[r]).to_basefield_array()); - } - } - - fn eval_unfiltered_circuit( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec> { - let insertion_index = vars.local_wires[self.wires_insertion_index()]; - let list_items = (0..self.vec_size) - .map(|i| vars.get_local_ext_algebra(self.wires_original_list_item(i))) - .collect::>(); - let output_list_items = (0..=self.vec_size) - .map(|i| vars.get_local_ext_algebra(self.wires_output_list_item(i))) - .collect::>(); - let element_to_insert = vars.get_local_ext_algebra(self.wires_element_to_insert()); - - let mut constraints = Vec::with_capacity(self.num_constraints()); - let mut already_inserted = builder.constant_extension(F::Extension::ZERO); - for r in 0..=self.vec_size { - let cur_index_ext = F::Extension::from_canonical_usize(r); - let cur_index = builder.constant_extension(cur_index_ext); - - let difference = builder.sub_extension(cur_index, insertion_index); - let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_round_r(r)]; - let insert_here = vars.local_wires[self.wire_insert_here_for_round_r(r)]; - - // The two equality constraints. - let prod = builder.mul_extension(difference, equality_dummy); - let one = builder.constant_extension(F::Extension::ONE); - let not_insert_here = builder.sub_extension(one, insert_here); - let first_equality_constraint = builder.sub_extension(prod, not_insert_here); - constraints.push(first_equality_constraint); - - let second_equality_constraint = builder.mul_extension(insert_here, difference); - constraints.push(second_equality_constraint); - - let mut new_item = builder.scalar_mul_ext_algebra(insert_here, element_to_insert); - if r > 0 { - new_item = builder.scalar_mul_add_ext_algebra( - already_inserted, - list_items[r - 1], - new_item, - ); - } - already_inserted = builder.add_extension(already_inserted, insert_here); - if r < self.vec_size { - let not_already_inserted = builder.sub_extension(one, already_inserted); - new_item = builder.scalar_mul_add_ext_algebra( - not_already_inserted, - list_items[r], - new_item, - ); - } - - // Output constraint. - let diff = builder.sub_ext_algebra(new_item, output_list_items[r]); - constraints.extend(diff.to_ext_target_array()); - } - - constraints - } - - fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { - let gen = InsertionGenerator:: { - row, - gate: self.clone(), - }; - vec![Box::new(gen.adapter())] - } - - fn num_wires(&self) -> usize { - self.wire_insert_here_for_round_r(self.vec_size) + 1 - } - - fn num_constants(&self) -> usize { - 0 - } - - fn degree(&self) -> usize { - 2 - } - - fn num_constraints(&self) -> usize { - (self.vec_size + 1) * (2 + D) - } -} - -#[derive(Debug)] -struct InsertionGenerator, const D: usize> { - row: usize, - gate: InsertionGate, -} - -impl, const D: usize> SimpleGenerator for InsertionGenerator { - fn dependencies(&self) -> Vec { - let local_target = |column| Target::wire(self.row, column); - - let local_targets = |columns: Range| columns.map(local_target); - - let mut deps = vec![local_target(self.gate.wires_insertion_index())]; - deps.extend(local_targets(self.gate.wires_element_to_insert())); - for i in 0..self.gate.vec_size { - deps.extend(local_targets(self.gate.wires_original_list_item(i))); - } - deps - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let local_wire = |column| Wire { - row: self.row, - column, - }; - - let get_local_wire = |column| witness.get_wire(local_wire(column)); - - let get_local_ext = |wire_range: Range| { - debug_assert_eq!(wire_range.len(), D); - let values = wire_range.map(get_local_wire).collect::>(); - let arr = values.try_into().unwrap(); - F::Extension::from_basefield_array(arr) - }; - - // Compute the new vector and the values for equality_dummy and insert_here - let vec_size = self.gate.vec_size; - let orig_vec = (0..vec_size) - .map(|i| get_local_ext(self.gate.wires_original_list_item(i))) - .collect::>(); - let to_insert = get_local_ext(self.gate.wires_element_to_insert()); - let insertion_index_f = get_local_wire(self.gate.wires_insertion_index()); - - let insertion_index = insertion_index_f.to_canonical_u64() as usize; - debug_assert!( - insertion_index <= vec_size, - "Insertion index {} is larger than the vector size {}", - insertion_index, - vec_size - ); - - let mut new_vec = orig_vec; - new_vec.insert(insertion_index, to_insert); - - let mut equality_dummy_vals = Vec::new(); - for i in 0..=vec_size { - equality_dummy_vals.push(if i == insertion_index { - F::ONE - } else { - (F::from_canonical_usize(i) - insertion_index_f).inverse() - }); - } - - let mut insert_here_vals = vec![F::ZERO; vec_size]; - insert_here_vals.insert(insertion_index, F::ONE); - - for i in 0..=vec_size { - let output_wires = self.gate.wires_output_list_item(i).map(local_wire); - out_buffer.set_ext_wires(output_wires, new_vec[i]); - let equality_dummy_wire = local_wire(self.gate.wire_equality_dummy_for_round_r(i)); - out_buffer.set_wire(equality_dummy_wire, equality_dummy_vals[i]); - let insert_here_wire = local_wire(self.gate.wire_insert_here_for_round_r(i)); - out_buffer.set_wire(insert_here_wire, insert_here_vals[i]); - } - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::field::goldilocks_field::GoldilocksField; - use plonky2::field::types::Sample; - use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; - use plonky2::hash::hash_types::HashOut; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - - use super::*; - - #[test] - fn wire_indices() { - let gate = InsertionGate:: { - vec_size: 3, - _phantom: PhantomData, - }; - - assert_eq!(gate.wires_insertion_index(), 0); - assert_eq!(gate.wires_element_to_insert(), 1..5); - assert_eq!(gate.wires_original_list_item(0), 5..9); - assert_eq!(gate.wires_original_list_item(2), 13..17); - assert_eq!(gate.wires_output_list_item(0), 17..21); - assert_eq!(gate.wires_output_list_item(3), 29..33); - assert_eq!(gate.wire_equality_dummy_for_round_r(0), 33); - assert_eq!(gate.wire_equality_dummy_for_round_r(3), 36); - assert_eq!(gate.wire_insert_here_for_round_r(0), 37); - assert_eq!(gate.wire_insert_here_for_round_r(3), 40); - } - - #[test] - fn low_degree() { - test_low_degree::(InsertionGate::new(4)); - } - - #[test] - fn eval_fns() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - test_eval_fns::(InsertionGate::new(4)) - } - - #[test] - fn test_gate_constraint() { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type FF = >::FE; - - /// Returns the local wires for an insertion gate given the original vector, element to - /// insert, and index. - fn get_wires(orig_vec: Vec, insertion_index: usize, element_to_insert: FF) -> Vec { - let vec_size = orig_vec.len(); - - let mut v = vec![F::from_canonical_usize(insertion_index)]; - v.extend(element_to_insert.0); - for j in 0..vec_size { - v.extend(orig_vec[j].0); - } - - let mut new_vec = orig_vec; - new_vec.insert(insertion_index, element_to_insert); - let mut equality_dummy_vals = Vec::new(); - for i in 0..=vec_size { - equality_dummy_vals.push(if i == insertion_index { - F::ONE - } else { - (F::from_canonical_usize(i) - F::from_canonical_usize(insertion_index)) - .inverse() - }); - } - let mut insert_here_vals = vec![F::ZERO; vec_size]; - insert_here_vals.insert(insertion_index, F::ONE); - - for j in 0..=vec_size { - v.extend(new_vec[j].0); - } - v.extend(equality_dummy_vals); - v.extend(insert_here_vals); - - v.iter().map(|&x| x.into()).collect() - } - - let orig_vec = vec![FF::rand(); 3]; - let insertion_index = 1; - let element_to_insert = FF::rand(); - let gate = InsertionGate:: { - vec_size: 3, - _phantom: PhantomData, - }; - let vars = EvaluationVars { - local_constants: &[], - local_wires: &get_wires(orig_vec, insertion_index, element_to_insert), - public_inputs_hash: &HashOut::rand(), - }; - - assert!( - gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), - "Gate constraints are not satisfied." - ); - } -} diff --git a/insertion/src/lib.rs b/insertion/src/lib.rs deleted file mode 100644 index e71919dd..00000000 --- a/insertion/src/lib.rs +++ /dev/null @@ -1,12 +0,0 @@ -#![allow(clippy::new_without_default)] -#![allow(clippy::too_many_arguments)] -#![allow(clippy::type_complexity)] -#![allow(clippy::len_without_is_empty)] -#![allow(clippy::needless_range_loop)] -#![allow(clippy::return_self_not_must_use)] -#![no_std] - -extern crate alloc; - -pub mod insert_gadget; -pub mod insertion_gate; diff --git a/u32/Cargo.toml b/u32/Cargo.toml deleted file mode 100644 index bde1b534..00000000 --- a/u32/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -name = "plonky2_u32" -description = "u32 gadget for Plonky2" -version = "0.1.0" -license = "MIT OR Apache-2.0" -repository = "https://github.com/mir-protocol/plonky2" -edition = "2021" - -[dependencies] -anyhow = { version = "1.0.40", default-features = false } -itertools = { version = "0.10.0", default-features = false } -num = { version = "0.4", default-features = false } -plonky2 = { version = "0.1.2", default-features = false } - -[dev-dependencies] -plonky2 = { version = "0.1.2", default-features = false, features = ["gate_testing"] } -rand = { version = "0.8.4", default-features = false, features = ["getrandom"] } diff --git a/u32/LICENSE-APACHE b/u32/LICENSE-APACHE deleted file mode 100644 index 1e5006dc..00000000 --- a/u32/LICENSE-APACHE +++ /dev/null @@ -1,202 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - diff --git a/u32/LICENSE-MIT b/u32/LICENSE-MIT deleted file mode 100644 index 86d690b2..00000000 --- a/u32/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2022 The Plonky2 Authors - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. diff --git a/u32/README.md b/u32/README.md deleted file mode 100644 index bb4e2d8a..00000000 --- a/u32/README.md +++ /dev/null @@ -1,13 +0,0 @@ -## License - -Licensed under either of - -* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) -* MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) - -at your option. - - -### Contribution - -Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions. diff --git a/u32/src/gadgets/arithmetic_u32.rs b/u32/src/gadgets/arithmetic_u32.rs deleted file mode 100644 index 65f5ac07..00000000 --- a/u32/src/gadgets/arithmetic_u32.rs +++ /dev/null @@ -1,303 +0,0 @@ -use alloc::vec; -use alloc::vec::Vec; -use core::marker::PhantomData; - -use plonky2::field::extension::Extendable; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; -use plonky2::iop::target::Target; -use plonky2::iop::witness::{PartitionWitness, Witness}; -use plonky2::plonk::circuit_builder::CircuitBuilder; - -use crate::gates::add_many_u32::U32AddManyGate; -use crate::gates::arithmetic_u32::U32ArithmeticGate; -use crate::gates::subtraction_u32::U32SubtractionGate; -use crate::witness::GeneratedValuesU32; - -#[derive(Clone, Copy, Debug)] -pub struct U32Target(pub Target); - -pub trait CircuitBuilderU32, const D: usize> { - fn add_virtual_u32_target(&mut self) -> U32Target; - - fn add_virtual_u32_targets(&mut self, n: usize) -> Vec; - - /// Returns a U32Target for the value `c`, which is assumed to be at most 32 bits. - fn constant_u32(&mut self, c: u32) -> U32Target; - - fn zero_u32(&mut self) -> U32Target; - - fn one_u32(&mut self) -> U32Target; - - fn connect_u32(&mut self, x: U32Target, y: U32Target); - - fn assert_zero_u32(&mut self, x: U32Target); - - /// Checks for special cases where the value of - /// `x * y + z` - /// can be determined without adding a `U32ArithmeticGate`. - fn arithmetic_u32_special_cases( - &mut self, - x: U32Target, - y: U32Target, - z: U32Target, - ) -> Option<(U32Target, U32Target)>; - - // Returns x * y + z. - fn mul_add_u32(&mut self, x: U32Target, y: U32Target, z: U32Target) -> (U32Target, U32Target); - - fn add_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target); - - fn add_many_u32(&mut self, to_add: &[U32Target]) -> (U32Target, U32Target); - - fn add_u32s_with_carry( - &mut self, - to_add: &[U32Target], - carry: U32Target, - ) -> (U32Target, U32Target); - - fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target); - - // Returns x - y - borrow, as a pair (result, borrow), where borrow is 0 or 1 depending on whether borrowing from the next digit is required (iff y + borrow > x). - fn sub_u32(&mut self, x: U32Target, y: U32Target, borrow: U32Target) -> (U32Target, U32Target); -} - -impl, const D: usize> CircuitBuilderU32 - for CircuitBuilder -{ - fn add_virtual_u32_target(&mut self) -> U32Target { - U32Target(self.add_virtual_target()) - } - - fn add_virtual_u32_targets(&mut self, n: usize) -> Vec { - self.add_virtual_targets(n) - .into_iter() - .map(U32Target) - .collect() - } - - /// Returns a U32Target for the value `c`, which is assumed to be at most 32 bits. - fn constant_u32(&mut self, c: u32) -> U32Target { - U32Target(self.constant(F::from_canonical_u32(c))) - } - - fn zero_u32(&mut self) -> U32Target { - U32Target(self.zero()) - } - - fn one_u32(&mut self) -> U32Target { - U32Target(self.one()) - } - - fn connect_u32(&mut self, x: U32Target, y: U32Target) { - self.connect(x.0, y.0) - } - - fn assert_zero_u32(&mut self, x: U32Target) { - self.assert_zero(x.0) - } - - /// Checks for special cases where the value of - /// `x * y + z` - /// can be determined without adding a `U32ArithmeticGate`. - fn arithmetic_u32_special_cases( - &mut self, - x: U32Target, - y: U32Target, - z: U32Target, - ) -> Option<(U32Target, U32Target)> { - let x_const = self.target_as_constant(x.0); - let y_const = self.target_as_constant(y.0); - let z_const = self.target_as_constant(z.0); - - // If both terms are constant, return their (constant) sum. - let first_term_const = if let (Some(xx), Some(yy)) = (x_const, y_const) { - Some(xx * yy) - } else { - None - }; - - if let (Some(a), Some(b)) = (first_term_const, z_const) { - let sum = (a + b).to_canonical_u64(); - let (low, high) = (sum as u32, (sum >> 32) as u32); - return Some((self.constant_u32(low), self.constant_u32(high))); - } - - None - } - - // Returns x * y + z. - fn mul_add_u32(&mut self, x: U32Target, y: U32Target, z: U32Target) -> (U32Target, U32Target) { - if let Some(result) = self.arithmetic_u32_special_cases(x, y, z) { - return result; - } - - let gate = U32ArithmeticGate::::new_from_config(&self.config); - let (row, copy) = self.find_slot(gate, &[], &[]); - - self.connect(Target::wire(row, gate.wire_ith_multiplicand_0(copy)), x.0); - self.connect(Target::wire(row, gate.wire_ith_multiplicand_1(copy)), y.0); - self.connect(Target::wire(row, gate.wire_ith_addend(copy)), z.0); - - let output_low = U32Target(Target::wire(row, gate.wire_ith_output_low_half(copy))); - let output_high = U32Target(Target::wire(row, gate.wire_ith_output_high_half(copy))); - - (output_low, output_high) - } - - fn add_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { - let one = self.one_u32(); - self.mul_add_u32(a, one, b) - } - - fn add_many_u32(&mut self, to_add: &[U32Target]) -> (U32Target, U32Target) { - match to_add.len() { - 0 => (self.zero_u32(), self.zero_u32()), - 1 => (to_add[0], self.zero_u32()), - 2 => self.add_u32(to_add[0], to_add[1]), - _ => { - let num_addends = to_add.len(); - let gate = U32AddManyGate::::new_from_config(&self.config, num_addends); - let (row, copy) = - self.find_slot(gate, &[F::from_canonical_usize(num_addends)], &[]); - - for j in 0..num_addends { - self.connect( - Target::wire(row, gate.wire_ith_op_jth_addend(copy, j)), - to_add[j].0, - ); - } - let zero = self.zero(); - self.connect(Target::wire(row, gate.wire_ith_carry(copy)), zero); - - let output_low = U32Target(Target::wire(row, gate.wire_ith_output_result(copy))); - let output_high = U32Target(Target::wire(row, gate.wire_ith_output_carry(copy))); - - (output_low, output_high) - } - } - } - - fn add_u32s_with_carry( - &mut self, - to_add: &[U32Target], - carry: U32Target, - ) -> (U32Target, U32Target) { - if to_add.len() == 1 { - return self.add_u32(to_add[0], carry); - } - - let num_addends = to_add.len(); - - let gate = U32AddManyGate::::new_from_config(&self.config, num_addends); - let (row, copy) = self.find_slot(gate, &[F::from_canonical_usize(num_addends)], &[]); - - for j in 0..num_addends { - self.connect( - Target::wire(row, gate.wire_ith_op_jth_addend(copy, j)), - to_add[j].0, - ); - } - self.connect(Target::wire(row, gate.wire_ith_carry(copy)), carry.0); - - let output = U32Target(Target::wire(row, gate.wire_ith_output_result(copy))); - let output_carry = U32Target(Target::wire(row, gate.wire_ith_output_carry(copy))); - - (output, output_carry) - } - - fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { - let zero = self.zero_u32(); - self.mul_add_u32(a, b, zero) - } - - // Returns x - y - borrow, as a pair (result, borrow), where borrow is 0 or 1 depending on whether borrowing from the next digit is required (iff y + borrow > x). - fn sub_u32(&mut self, x: U32Target, y: U32Target, borrow: U32Target) -> (U32Target, U32Target) { - let gate = U32SubtractionGate::::new_from_config(&self.config); - let (row, copy) = self.find_slot(gate, &[], &[]); - - self.connect(Target::wire(row, gate.wire_ith_input_x(copy)), x.0); - self.connect(Target::wire(row, gate.wire_ith_input_y(copy)), y.0); - self.connect( - Target::wire(row, gate.wire_ith_input_borrow(copy)), - borrow.0, - ); - - let output_result = U32Target(Target::wire(row, gate.wire_ith_output_result(copy))); - let output_borrow = U32Target(Target::wire(row, gate.wire_ith_output_borrow(copy))); - - (output_result, output_borrow) - } -} - -#[derive(Debug)] -struct SplitToU32Generator, const D: usize> { - x: Target, - low: U32Target, - high: U32Target, - _phantom: PhantomData, -} - -impl, const D: usize> SimpleGenerator - for SplitToU32Generator -{ - fn dependencies(&self) -> Vec { - vec![self.x] - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let x = witness.get_target(self.x); - let x_u64 = x.to_canonical_u64(); - let low = x_u64 as u32; - let high = (x_u64 >> 32) as u32; - - out_buffer.set_u32_target(self.low, low); - out_buffer.set_u32_target(self.high, high); - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::iop::witness::PartialWitness; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use rand::rngs::OsRng; - use rand::Rng; - - use super::*; - - #[test] - pub fn test_add_many_u32s() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - const NUM_ADDENDS: usize = 15; - - let config = CircuitConfig::standard_recursion_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let mut rng = OsRng; - let mut to_add = Vec::new(); - let mut sum = 0u64; - for _ in 0..NUM_ADDENDS { - let x: u32 = rng.gen(); - sum += x as u64; - to_add.push(builder.constant_u32(x)); - } - let carry = builder.zero_u32(); - let (result_low, result_high) = builder.add_u32s_with_carry(&to_add, carry); - let expected_low = builder.constant_u32((sum % (1 << 32)) as u32); - let expected_high = builder.constant_u32((sum >> 32) as u32); - - builder.connect_u32(result_low, expected_low); - builder.connect_u32(result_high, expected_high); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof) - } -} diff --git a/u32/src/gadgets/mod.rs b/u32/src/gadgets/mod.rs deleted file mode 100644 index 622242ea..00000000 --- a/u32/src/gadgets/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod arithmetic_u32; -pub mod multiple_comparison; -pub mod range_check; diff --git a/u32/src/gadgets/multiple_comparison.rs b/u32/src/gadgets/multiple_comparison.rs deleted file mode 100644 index 8d82c296..00000000 --- a/u32/src/gadgets/multiple_comparison.rs +++ /dev/null @@ -1,152 +0,0 @@ -use alloc::vec; -use alloc::vec::Vec; - -use plonky2::field::extension::Extendable; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::target::{BoolTarget, Target}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::util::ceil_div_usize; - -use crate::gadgets::arithmetic_u32::U32Target; -use crate::gates::comparison::ComparisonGate; - -/// Returns true if a is less than or equal to b, considered as base-`2^num_bits` limbs of a large value. -/// This range-checks its inputs. -pub fn list_le_circuit, const D: usize>( - builder: &mut CircuitBuilder, - a: Vec, - b: Vec, - num_bits: usize, -) -> BoolTarget { - assert_eq!( - a.len(), - b.len(), - "Comparison must be between same number of inputs and outputs" - ); - let n = a.len(); - - let chunk_bits = 2; - let num_chunks = ceil_div_usize(num_bits, chunk_bits); - - let one = builder.one(); - let mut result = one; - for i in 0..n { - let a_le_b_gate = ComparisonGate::new(num_bits, num_chunks); - let a_le_b_row = builder.add_gate(a_le_b_gate.clone(), vec![]); - builder.connect( - Target::wire(a_le_b_row, a_le_b_gate.wire_first_input()), - a[i], - ); - builder.connect( - Target::wire(a_le_b_row, a_le_b_gate.wire_second_input()), - b[i], - ); - let a_le_b_result = Target::wire(a_le_b_row, a_le_b_gate.wire_result_bool()); - - let b_le_a_gate = ComparisonGate::new(num_bits, num_chunks); - let b_le_a_row = builder.add_gate(b_le_a_gate.clone(), vec![]); - builder.connect( - Target::wire(b_le_a_row, b_le_a_gate.wire_first_input()), - b[i], - ); - builder.connect( - Target::wire(b_le_a_row, b_le_a_gate.wire_second_input()), - a[i], - ); - let b_le_a_result = Target::wire(b_le_a_row, b_le_a_gate.wire_result_bool()); - - let these_limbs_equal = builder.mul(a_le_b_result, b_le_a_result); - let these_limbs_less_than = builder.sub(one, b_le_a_result); - result = builder.mul_add(these_limbs_equal, result, these_limbs_less_than); - } - - // `result` being boolean is an invariant, maintained because its new value is always - // `x * result + y`, where `x` and `y` are booleans that are not simultaneously true. - BoolTarget::new_unsafe(result) -} - -/// Helper function for comparing, specifically, lists of `U32Target`s. -pub fn list_le_u32_circuit, const D: usize>( - builder: &mut CircuitBuilder, - a: Vec, - b: Vec, -) -> BoolTarget { - let a_targets: Vec = a.iter().map(|&t| t.0).collect(); - let b_targets: Vec = b.iter().map(|&t| t.0).collect(); - - list_le_circuit(builder, a_targets, b_targets, 32) -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use num::BigUint; - use plonky2::field::types::Field; - use plonky2::iop::witness::PartialWitness; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use rand::rngs::OsRng; - use rand::Rng; - - use super::*; - - fn test_list_le(size: usize, num_bits: usize) -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let mut rng = OsRng; - - let lst1: Vec = (0..size) - .map(|_| rng.gen_range(0..(1 << num_bits))) - .collect(); - let lst2: Vec = (0..size) - .map(|_| rng.gen_range(0..(1 << num_bits))) - .collect(); - - let a_biguint = BigUint::from_slice( - &lst1 - .iter() - .flat_map(|&x| [x as u32, (x >> 32) as u32]) - .collect::>(), - ); - let b_biguint = BigUint::from_slice( - &lst2 - .iter() - .flat_map(|&x| [x as u32, (x >> 32) as u32]) - .collect::>(), - ); - - let a = lst1 - .iter() - .map(|&x| builder.constant(F::from_canonical_u64(x))) - .collect(); - let b = lst2 - .iter() - .map(|&x| builder.constant(F::from_canonical_u64(x))) - .collect(); - - let result = list_le_circuit(&mut builder, a, b, num_bits); - - let expected_result = builder.constant_bool(a_biguint <= b_biguint); - builder.connect(result.target, expected_result.target); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - data.verify(proof) - } - - #[test] - fn test_multiple_comparison() -> Result<()> { - for size in [1, 3, 6] { - for num_bits in [20, 32, 40, 44] { - test_list_le(size, num_bits).unwrap(); - } - } - - Ok(()) - } -} diff --git a/u32/src/gadgets/range_check.rs b/u32/src/gadgets/range_check.rs deleted file mode 100644 index 9e8cf2ad..00000000 --- a/u32/src/gadgets/range_check.rs +++ /dev/null @@ -1,23 +0,0 @@ -use alloc::vec; -use alloc::vec::Vec; - -use plonky2::field::extension::Extendable; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::target::Target; -use plonky2::plonk::circuit_builder::CircuitBuilder; - -use crate::gadgets::arithmetic_u32::U32Target; -use crate::gates::range_check_u32::U32RangeCheckGate; - -pub fn range_check_u32_circuit, const D: usize>( - builder: &mut CircuitBuilder, - vals: Vec, -) { - let num_input_limbs = vals.len(); - let gate = U32RangeCheckGate::::new(num_input_limbs); - let row = builder.add_gate(gate, vec![]); - - for i in 0..num_input_limbs { - builder.connect(Target::wire(row, gate.wire_ith_input_limb(i)), vals[i].0); - } -} diff --git a/u32/src/gates/add_many_u32.rs b/u32/src/gates/add_many_u32.rs deleted file mode 100644 index 566a7827..00000000 --- a/u32/src/gates/add_many_u32.rs +++ /dev/null @@ -1,456 +0,0 @@ -use alloc::boxed::Box; -use alloc::format; -use alloc::string::String; -use alloc::vec::Vec; -use core::marker::PhantomData; - -use itertools::unfold; -use plonky2::field::extension::Extendable; -use plonky2::field::types::Field; -use plonky2::gates::gate::Gate; -use plonky2::gates::util::StridedConstraintConsumer; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::ext_target::ExtensionTarget; -use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; -use plonky2::iop::target::Target; -use plonky2::iop::wire::Wire; -use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::CircuitConfig; -use plonky2::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -use plonky2::util::ceil_div_usize; - -const LOG2_MAX_NUM_ADDENDS: usize = 4; -const MAX_NUM_ADDENDS: usize = 16; - -/// A gate to perform addition on `num_addends` different 32-bit values, plus a small carry -#[derive(Copy, Clone, Debug)] -pub struct U32AddManyGate, const D: usize> { - pub num_addends: usize, - pub num_ops: usize, - _phantom: PhantomData, -} - -impl, const D: usize> U32AddManyGate { - pub fn new_from_config(config: &CircuitConfig, num_addends: usize) -> Self { - Self { - num_addends, - num_ops: Self::num_ops(num_addends, config), - _phantom: PhantomData, - } - } - - pub(crate) fn num_ops(num_addends: usize, config: &CircuitConfig) -> usize { - debug_assert!(num_addends <= MAX_NUM_ADDENDS); - let wires_per_op = (num_addends + 3) + Self::num_limbs(); - let routed_wires_per_op = num_addends + 3; - (config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op) - } - - pub fn wire_ith_op_jth_addend(&self, i: usize, j: usize) -> usize { - debug_assert!(i < self.num_ops); - debug_assert!(j < self.num_addends); - (self.num_addends + 3) * i + j - } - pub fn wire_ith_carry(&self, i: usize) -> usize { - debug_assert!(i < self.num_ops); - (self.num_addends + 3) * i + self.num_addends - } - - pub fn wire_ith_output_result(&self, i: usize) -> usize { - debug_assert!(i < self.num_ops); - (self.num_addends + 3) * i + self.num_addends + 1 - } - pub fn wire_ith_output_carry(&self, i: usize) -> usize { - debug_assert!(i < self.num_ops); - (self.num_addends + 3) * i + self.num_addends + 2 - } - - pub fn limb_bits() -> usize { - 2 - } - pub fn num_result_limbs() -> usize { - ceil_div_usize(32, Self::limb_bits()) - } - pub fn num_carry_limbs() -> usize { - ceil_div_usize(LOG2_MAX_NUM_ADDENDS, Self::limb_bits()) - } - pub fn num_limbs() -> usize { - Self::num_result_limbs() + Self::num_carry_limbs() - } - - pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize { - debug_assert!(i < self.num_ops); - debug_assert!(j < Self::num_limbs()); - (self.num_addends + 3) * self.num_ops + Self::num_limbs() * i + j - } -} - -impl, const D: usize> Gate for U32AddManyGate { - fn id(&self) -> String { - format!("{self:?}") - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..self.num_ops { - let addends: Vec = (0..self.num_addends) - .map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)]) - .collect(); - let carry = vars.local_wires[self.wire_ith_carry(i)]; - - let computed_output = addends.iter().fold(F::Extension::ZERO, |x, &y| x + y) + carry; - - let output_result = vars.local_wires[self.wire_ith_output_result(i)]; - let output_carry = vars.local_wires[self.wire_ith_output_carry(i)]; - - let base = F::Extension::from_canonical_u64(1 << 32u64); - let combined_output = output_carry * base + output_result; - - constraints.push(combined_output - computed_output); - - let mut combined_result_limbs = F::Extension::ZERO; - let mut combined_carry_limbs = F::Extension::ZERO; - let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); - for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; - let max_limb = 1 << Self::limb_bits(); - let product = (0..max_limb) - .map(|x| this_limb - F::Extension::from_canonical_usize(x)) - .product(); - constraints.push(product); - - if j < Self::num_result_limbs() { - combined_result_limbs = base * combined_result_limbs + this_limb; - } else { - combined_carry_limbs = base * combined_carry_limbs + this_limb; - } - } - constraints.push(combined_result_limbs - output_result); - constraints.push(combined_carry_limbs - output_carry); - } - - constraints - } - - fn eval_unfiltered_base_one( - &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, - ) { - for i in 0..self.num_ops { - let addends: Vec = (0..self.num_addends) - .map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)]) - .collect(); - let carry = vars.local_wires[self.wire_ith_carry(i)]; - - let computed_output = addends.iter().fold(F::ZERO, |x, &y| x + y) + carry; - - let output_result = vars.local_wires[self.wire_ith_output_result(i)]; - let output_carry = vars.local_wires[self.wire_ith_output_carry(i)]; - - let base = F::from_canonical_u64(1 << 32u64); - let combined_output = output_carry * base + output_result; - - yield_constr.one(combined_output - computed_output); - - let mut combined_result_limbs = F::ZERO; - let mut combined_carry_limbs = F::ZERO; - let base = F::from_canonical_u64(1u64 << Self::limb_bits()); - for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; - let max_limb = 1 << Self::limb_bits(); - let product = (0..max_limb) - .map(|x| this_limb - F::from_canonical_usize(x)) - .product(); - yield_constr.one(product); - - if j < Self::num_result_limbs() { - combined_result_limbs = base * combined_result_limbs + this_limb; - } else { - combined_carry_limbs = base * combined_carry_limbs + this_limb; - } - } - yield_constr.one(combined_result_limbs - output_result); - yield_constr.one(combined_carry_limbs - output_carry); - } - } - - fn eval_unfiltered_circuit( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec> { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - for i in 0..self.num_ops { - let addends: Vec> = (0..self.num_addends) - .map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)]) - .collect(); - let carry = vars.local_wires[self.wire_ith_carry(i)]; - - let mut computed_output = carry; - for addend in addends { - computed_output = builder.add_extension(computed_output, addend); - } - - let output_result = vars.local_wires[self.wire_ith_output_result(i)]; - let output_carry = vars.local_wires[self.wire_ith_output_carry(i)]; - - let base: F::Extension = F::from_canonical_u64(1 << 32u64).into(); - let base_target = builder.constant_extension(base); - let combined_output = - builder.mul_add_extension(output_carry, base_target, output_result); - - constraints.push(builder.sub_extension(combined_output, computed_output)); - - let mut combined_result_limbs = builder.zero_extension(); - let mut combined_carry_limbs = builder.zero_extension(); - let base = builder - .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); - for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; - let max_limb = 1 << Self::limb_bits(); - - let mut product = builder.one_extension(); - for x in 0..max_limb { - let x_target = - builder.constant_extension(F::Extension::from_canonical_usize(x)); - let diff = builder.sub_extension(this_limb, x_target); - product = builder.mul_extension(product, diff); - } - constraints.push(product); - - if j < Self::num_result_limbs() { - combined_result_limbs = - builder.mul_add_extension(base, combined_result_limbs, this_limb); - } else { - combined_carry_limbs = - builder.mul_add_extension(base, combined_carry_limbs, this_limb); - } - } - constraints.push(builder.sub_extension(combined_result_limbs, output_result)); - constraints.push(builder.sub_extension(combined_carry_limbs, output_carry)); - } - - constraints - } - - fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { - (0..self.num_ops) - .map(|i| { - let g: Box> = Box::new( - U32AddManyGenerator { - gate: *self, - row, - i, - _phantom: PhantomData, - } - .adapter(), - ); - g - }) - .collect() - } - - fn num_wires(&self) -> usize { - (self.num_addends + 3) * self.num_ops + Self::num_limbs() * self.num_ops - } - - fn num_constants(&self) -> usize { - 0 - } - - fn degree(&self) -> usize { - 1 << Self::limb_bits() - } - - fn num_constraints(&self) -> usize { - self.num_ops * (3 + Self::num_limbs()) - } -} - -#[derive(Clone, Debug)] -struct U32AddManyGenerator, const D: usize> { - gate: U32AddManyGate, - row: usize, - i: usize, - _phantom: PhantomData, -} - -impl, const D: usize> SimpleGenerator - for U32AddManyGenerator -{ - fn dependencies(&self) -> Vec { - let local_target = |column| Target::wire(self.row, column); - - (0..self.gate.num_addends) - .map(|j| local_target(self.gate.wire_ith_op_jth_addend(self.i, j))) - .chain([local_target(self.gate.wire_ith_carry(self.i))]) - .collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let local_wire = |column| Wire { - row: self.row, - column, - }; - - let get_local_wire = |column| witness.get_wire(local_wire(column)); - - let addends: Vec<_> = (0..self.gate.num_addends) - .map(|j| get_local_wire(self.gate.wire_ith_op_jth_addend(self.i, j))) - .collect(); - let carry = get_local_wire(self.gate.wire_ith_carry(self.i)); - - let output = addends.iter().fold(F::ZERO, |x, &y| x + y) + carry; - let output_u64 = output.to_canonical_u64(); - - let output_carry_u64 = output_u64 >> 32; - let output_result_u64 = output_u64 & ((1 << 32) - 1); - - let output_carry = F::from_canonical_u64(output_carry_u64); - let output_result = F::from_canonical_u64(output_result_u64); - - let output_carry_wire = local_wire(self.gate.wire_ith_output_carry(self.i)); - let output_result_wire = local_wire(self.gate.wire_ith_output_result(self.i)); - - out_buffer.set_wire(output_carry_wire, output_carry); - out_buffer.set_wire(output_result_wire, output_result); - - let num_result_limbs = U32AddManyGate::::num_result_limbs(); - let num_carry_limbs = U32AddManyGate::::num_carry_limbs(); - let limb_base = 1 << U32AddManyGate::::limb_bits(); - - let split_to_limbs = |mut val, num| { - unfold((), move |_| { - let ret = val % limb_base; - val /= limb_base; - Some(ret) - }) - .take(num) - .map(F::from_canonical_u64) - }; - - let result_limbs = split_to_limbs(output_result_u64, num_result_limbs); - let carry_limbs = split_to_limbs(output_carry_u64, num_carry_limbs); - - for (j, limb) in result_limbs.chain(carry_limbs).enumerate() { - let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j)); - out_buffer.set_wire(wire, limb); - } - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::field::extension::quartic::QuarticExtension; - use plonky2::field::goldilocks_field::GoldilocksField; - use plonky2::field::types::Sample; - use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; - use plonky2::hash::hash_types::HashOut; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use rand::rngs::OsRng; - use rand::Rng; - - use super::*; - - #[test] - fn low_degree() { - test_low_degree::(U32AddManyGate:: { - num_addends: 4, - num_ops: 3, - _phantom: PhantomData, - }) - } - - #[test] - fn eval_fns() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - test_eval_fns::(U32AddManyGate:: { - num_addends: 4, - num_ops: 3, - _phantom: PhantomData, - }) - } - - #[test] - fn test_gate_constraint() { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; - const NUM_ADDENDS: usize = 10; - const NUM_U32_ADD_MANY_OPS: usize = 3; - - fn get_wires(addends: Vec>, carries: Vec) -> Vec { - let mut v0 = Vec::new(); - let mut v1 = Vec::new(); - - let num_result_limbs = U32AddManyGate::::num_result_limbs(); - let num_carry_limbs = U32AddManyGate::::num_carry_limbs(); - let limb_base = 1 << U32AddManyGate::::limb_bits(); - for op in 0..NUM_U32_ADD_MANY_OPS { - let adds = &addends[op]; - let ca = carries[op]; - - let output = adds.iter().sum::() + ca; - let output_result = output & ((1 << 32) - 1); - let output_carry = output >> 32; - - let split_to_limbs = |mut val, num| { - unfold((), move |_| { - let ret = val % limb_base; - val /= limb_base; - Some(ret) - }) - .take(num) - .map(F::from_canonical_u64) - }; - - let mut result_limbs: Vec<_> = - split_to_limbs(output_result, num_result_limbs).collect(); - let mut carry_limbs: Vec<_> = - split_to_limbs(output_carry, num_carry_limbs).collect(); - - for a in adds { - v0.push(F::from_canonical_u64(*a)); - } - v0.push(F::from_canonical_u64(ca)); - v0.push(F::from_canonical_u64(output_result)); - v0.push(F::from_canonical_u64(output_carry)); - v1.append(&mut result_limbs); - v1.append(&mut carry_limbs); - } - - v0.iter().chain(v1.iter()).map(|&x| x.into()).collect() - } - - let mut rng = OsRng; - let addends: Vec> = (0..NUM_U32_ADD_MANY_OPS) - .map(|_| (0..NUM_ADDENDS).map(|_| rng.gen::() as u64).collect()) - .collect(); - let carries: Vec<_> = (0..NUM_U32_ADD_MANY_OPS) - .map(|_| rng.gen::() as u64) - .collect(); - - let gate = U32AddManyGate:: { - num_addends: NUM_ADDENDS, - num_ops: NUM_U32_ADD_MANY_OPS, - _phantom: PhantomData, - }; - - let vars = EvaluationVars { - local_constants: &[], - local_wires: &get_wires(addends, carries), - public_inputs_hash: &HashOut::rand(), - }; - - assert!( - gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), - "Gate constraints are not satisfied." - ); - } -} diff --git a/u32/src/gates/arithmetic_u32.rs b/u32/src/gates/arithmetic_u32.rs deleted file mode 100644 index c65b32a4..00000000 --- a/u32/src/gates/arithmetic_u32.rs +++ /dev/null @@ -1,575 +0,0 @@ -use alloc::boxed::Box; -use alloc::string::String; -use alloc::vec::Vec; -use alloc::{format, vec}; -use core::marker::PhantomData; - -use itertools::unfold; -use plonky2::field::extension::Extendable; -use plonky2::field::packed::PackedField; -use plonky2::field::types::Field; -use plonky2::gates::gate::Gate; -use plonky2::gates::packed_util::PackedEvaluableBase; -use plonky2::gates::util::StridedConstraintConsumer; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::ext_target::ExtensionTarget; -use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; -use plonky2::iop::target::Target; -use plonky2::iop::wire::Wire; -use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::CircuitConfig; -use plonky2::plonk::vars::{ - EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, - EvaluationVarsBasePacked, -}; - -/// A gate to perform a basic mul-add on 32-bit values (we assume they are range-checked beforehand). -#[derive(Copy, Clone, Debug)] -pub struct U32ArithmeticGate, const D: usize> { - pub num_ops: usize, - _phantom: PhantomData, -} - -impl, const D: usize> U32ArithmeticGate { - pub fn new_from_config(config: &CircuitConfig) -> Self { - Self { - num_ops: Self::num_ops(config), - _phantom: PhantomData, - } - } - - pub(crate) fn num_ops(config: &CircuitConfig) -> usize { - let wires_per_op = Self::routed_wires_per_op() + Self::num_limbs(); - (config.num_wires / wires_per_op).min(config.num_routed_wires / Self::routed_wires_per_op()) - } - - pub fn wire_ith_multiplicand_0(&self, i: usize) -> usize { - debug_assert!(i < self.num_ops); - Self::routed_wires_per_op() * i - } - pub fn wire_ith_multiplicand_1(&self, i: usize) -> usize { - debug_assert!(i < self.num_ops); - Self::routed_wires_per_op() * i + 1 - } - pub fn wire_ith_addend(&self, i: usize) -> usize { - debug_assert!(i < self.num_ops); - Self::routed_wires_per_op() * i + 2 - } - - pub fn wire_ith_output_low_half(&self, i: usize) -> usize { - debug_assert!(i < self.num_ops); - Self::routed_wires_per_op() * i + 3 - } - - pub fn wire_ith_output_high_half(&self, i: usize) -> usize { - debug_assert!(i < self.num_ops); - Self::routed_wires_per_op() * i + 4 - } - - pub fn wire_ith_inverse(&self, i: usize) -> usize { - debug_assert!(i < self.num_ops); - Self::routed_wires_per_op() * i + 5 - } - - pub fn limb_bits() -> usize { - 2 - } - pub fn num_limbs() -> usize { - 64 / Self::limb_bits() - } - pub fn routed_wires_per_op() -> usize { - 6 - } - pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize { - debug_assert!(i < self.num_ops); - debug_assert!(j < Self::num_limbs()); - Self::routed_wires_per_op() * self.num_ops + Self::num_limbs() * i + j - } -} - -impl, const D: usize> Gate for U32ArithmeticGate { - fn id(&self) -> String { - format!("{self:?}") - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..self.num_ops { - let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; - let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; - let addend = vars.local_wires[self.wire_ith_addend(i)]; - - let computed_output = multiplicand_0 * multiplicand_1 + addend; - - let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; - let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; - let inverse = vars.local_wires[self.wire_ith_inverse(i)]; - - // Check canonicity of combined_output = output_high * 2^32 + output_low - let combined_output = { - let base = F::Extension::from_canonical_u64(1 << 32u64); - let one = F::Extension::ONE; - let u32_max = F::Extension::from_canonical_u32(u32::MAX); - - // This is zero if and only if the high limb is `u32::MAX`. - // u32::MAX - output_high - let diff = u32_max - output_high; - // If this is zero, the diff is invertible, so the high limb is not `u32::MAX`. - // inverse * diff - 1 - let hi_not_max = inverse * diff - one; - // If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero. - // hi_not_max * limb_0_u32 - let hi_not_max_or_lo_zero = hi_not_max * output_low; - - constraints.push(hi_not_max_or_lo_zero); - - output_high * base + output_low - }; - - constraints.push(combined_output - computed_output); - - let mut combined_low_limbs = F::Extension::ZERO; - let mut combined_high_limbs = F::Extension::ZERO; - let midpoint = Self::num_limbs() / 2; - let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); - for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; - let max_limb = 1 << Self::limb_bits(); - let product = (0..max_limb) - .map(|x| this_limb - F::Extension::from_canonical_usize(x)) - .product(); - constraints.push(product); - - if j < midpoint { - combined_low_limbs = base * combined_low_limbs + this_limb; - } else { - combined_high_limbs = base * combined_high_limbs + this_limb; - } - } - constraints.push(combined_low_limbs - output_low); - constraints.push(combined_high_limbs - output_high); - } - - constraints - } - - fn eval_unfiltered_base_one( - &self, - _vars: EvaluationVarsBase, - _yield_constr: StridedConstraintConsumer, - ) { - panic!("use eval_unfiltered_base_packed instead"); - } - - fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { - self.eval_unfiltered_base_batch_packed(vars_base) - } - - fn eval_unfiltered_circuit( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec> { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - for i in 0..self.num_ops { - let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; - let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; - let addend = vars.local_wires[self.wire_ith_addend(i)]; - - let computed_output = builder.mul_add_extension(multiplicand_0, multiplicand_1, addend); - - let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; - let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; - let inverse = vars.local_wires[self.wire_ith_inverse(i)]; - - // Check canonicity of combined_output = output_high * 2^32 + output_low - let combined_output = { - let base: F::Extension = F::from_canonical_u64(1 << 32u64).into(); - let base_target = builder.constant_extension(base); - let one = builder.one_extension(); - let u32_max = - builder.constant_extension(F::Extension::from_canonical_u32(u32::MAX)); - - // This is zero if and only if the high limb is `u32::MAX`. - let diff = builder.sub_extension(u32_max, output_high); - // If this is zero, the diff is invertible, so the high limb is not `u32::MAX`. - let hi_not_max = builder.mul_sub_extension(inverse, diff, one); - // If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero. - let hi_not_max_or_lo_zero = builder.mul_extension(hi_not_max, output_low); - - constraints.push(hi_not_max_or_lo_zero); - - builder.mul_add_extension(output_high, base_target, output_low) - }; - - constraints.push(builder.sub_extension(combined_output, computed_output)); - - let mut combined_low_limbs = builder.zero_extension(); - let mut combined_high_limbs = builder.zero_extension(); - let midpoint = Self::num_limbs() / 2; - let base = builder - .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); - for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; - let max_limb = 1 << Self::limb_bits(); - - let mut product = builder.one_extension(); - for x in 0..max_limb { - let x_target = - builder.constant_extension(F::Extension::from_canonical_usize(x)); - let diff = builder.sub_extension(this_limb, x_target); - product = builder.mul_extension(product, diff); - } - constraints.push(product); - - if j < midpoint { - combined_low_limbs = - builder.mul_add_extension(base, combined_low_limbs, this_limb); - } else { - combined_high_limbs = - builder.mul_add_extension(base, combined_high_limbs, this_limb); - } - } - - constraints.push(builder.sub_extension(combined_low_limbs, output_low)); - constraints.push(builder.sub_extension(combined_high_limbs, output_high)); - } - - constraints - } - - fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { - (0..self.num_ops) - .map(|i| { - let g: Box> = Box::new( - U32ArithmeticGenerator { - gate: *self, - row, - i, - _phantom: PhantomData, - } - .adapter(), - ); - g - }) - .collect() - } - - fn num_wires(&self) -> usize { - self.num_ops * (Self::routed_wires_per_op() + Self::num_limbs()) - } - - fn num_constants(&self) -> usize { - 0 - } - - fn degree(&self) -> usize { - 1 << Self::limb_bits() - } - - fn num_constraints(&self) -> usize { - self.num_ops * (4 + Self::num_limbs()) - } -} - -impl, const D: usize> PackedEvaluableBase - for U32ArithmeticGate -{ - fn eval_unfiltered_base_packed>( - &self, - vars: EvaluationVarsBasePacked

, - mut yield_constr: StridedConstraintConsumer

, - ) { - for i in 0..self.num_ops { - let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; - let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; - let addend = vars.local_wires[self.wire_ith_addend(i)]; - - let computed_output = multiplicand_0 * multiplicand_1 + addend; - - let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; - let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; - let inverse = vars.local_wires[self.wire_ith_inverse(i)]; - - let combined_output = { - let base = P::from(F::from_canonical_u64(1 << 32u64)); - let one = P::ONES; - let u32_max = P::from(F::from_canonical_u32(u32::MAX)); - - // This is zero if and only if the high limb is `u32::MAX`. - // u32::MAX - output_high - let diff = u32_max - output_high; - // If this is zero, the diff is invertible, so the high limb is not `u32::MAX`. - // inverse * diff - 1 - let hi_not_max = inverse * diff - one; - // If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero. - // hi_not_max * limb_0_u32 - let hi_not_max_or_lo_zero = hi_not_max * output_low; - - yield_constr.one(hi_not_max_or_lo_zero); - - output_high * base + output_low - }; - - yield_constr.one(combined_output - computed_output); - - let mut combined_low_limbs = P::ZEROS; - let mut combined_high_limbs = P::ZEROS; - let midpoint = Self::num_limbs() / 2; - let base = F::from_canonical_u64(1u64 << Self::limb_bits()); - for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; - let max_limb = 1 << Self::limb_bits(); - let product = (0..max_limb) - .map(|x| this_limb - F::from_canonical_usize(x)) - .product(); - yield_constr.one(product); - - if j < midpoint { - combined_low_limbs = combined_low_limbs * base + this_limb; - } else { - combined_high_limbs = combined_high_limbs * base + this_limb; - } - } - yield_constr.one(combined_low_limbs - output_low); - yield_constr.one(combined_high_limbs - output_high); - } - } -} - -#[derive(Clone, Debug)] -struct U32ArithmeticGenerator, const D: usize> { - gate: U32ArithmeticGate, - row: usize, - i: usize, - _phantom: PhantomData, -} - -impl, const D: usize> SimpleGenerator - for U32ArithmeticGenerator -{ - fn dependencies(&self) -> Vec { - let local_target = |column| Target::wire(self.row, column); - - vec![ - local_target(self.gate.wire_ith_multiplicand_0(self.i)), - local_target(self.gate.wire_ith_multiplicand_1(self.i)), - local_target(self.gate.wire_ith_addend(self.i)), - ] - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let local_wire = |column| Wire { - row: self.row, - column, - }; - - let get_local_wire = |column| witness.get_wire(local_wire(column)); - - let multiplicand_0 = get_local_wire(self.gate.wire_ith_multiplicand_0(self.i)); - let multiplicand_1 = get_local_wire(self.gate.wire_ith_multiplicand_1(self.i)); - let addend = get_local_wire(self.gate.wire_ith_addend(self.i)); - - let output = multiplicand_0 * multiplicand_1 + addend; - let mut output_u64 = output.to_canonical_u64(); - - let output_high_u64 = output_u64 >> 32; - let output_low_u64 = output_u64 & ((1 << 32) - 1); - - let output_high = F::from_canonical_u64(output_high_u64); - let output_low = F::from_canonical_u64(output_low_u64); - - let output_high_wire = local_wire(self.gate.wire_ith_output_high_half(self.i)); - let output_low_wire = local_wire(self.gate.wire_ith_output_low_half(self.i)); - - out_buffer.set_wire(output_high_wire, output_high); - out_buffer.set_wire(output_low_wire, output_low); - - let diff = u32::MAX as u64 - output_high_u64; - let inverse = if diff == 0 { - F::ZERO - } else { - F::from_canonical_u64(diff).inverse() - }; - let inverse_wire = local_wire(self.gate.wire_ith_inverse(self.i)); - out_buffer.set_wire(inverse_wire, inverse); - - let num_limbs = U32ArithmeticGate::::num_limbs(); - let limb_base = 1 << U32ArithmeticGate::::limb_bits(); - let output_limbs_u64 = unfold((), move |_| { - let ret = output_u64 % limb_base; - output_u64 /= limb_base; - Some(ret) - }) - .take(num_limbs); - let output_limbs_f = output_limbs_u64.map(F::from_canonical_u64); - - for (j, output_limb) in output_limbs_f.enumerate() { - let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j)); - out_buffer.set_wire(wire, output_limb); - } - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::field::goldilocks_field::GoldilocksField; - use plonky2::field::types::Sample; - use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; - use plonky2::hash::hash_types::HashOut; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use rand::rngs::OsRng; - use rand::Rng; - - use super::*; - - #[test] - fn low_degree() { - test_low_degree::(U32ArithmeticGate:: { - num_ops: 3, - _phantom: PhantomData, - }) - } - - #[test] - fn eval_fns() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - test_eval_fns::(U32ArithmeticGate:: { - num_ops: 3, - _phantom: PhantomData, - }) - } - - fn get_wires< - F: RichField + Extendable, - FF: From, - const D: usize, - const NUM_U32_ARITHMETIC_OPS: usize, - >( - multiplicands_0: Vec, - multiplicands_1: Vec, - addends: Vec, - ) -> Vec { - let mut v0 = Vec::new(); - let mut v1 = Vec::new(); - - let limb_bits = U32ArithmeticGate::::limb_bits(); - let num_limbs = U32ArithmeticGate::::num_limbs(); - let limb_base = 1 << limb_bits; - for c in 0..NUM_U32_ARITHMETIC_OPS { - let m0 = multiplicands_0[c]; - let m1 = multiplicands_1[c]; - let a = addends[c]; - - let mut output = m0 * m1 + a; - let output_low = output & ((1 << 32) - 1); - let output_high = output >> 32; - let diff = u32::MAX as u64 - output_high; - let inverse = if diff == 0 { - F::ZERO - } else { - F::from_canonical_u64(diff).inverse() - }; - - let mut output_limbs = Vec::with_capacity(num_limbs); - for _i in 0..num_limbs { - output_limbs.push(output % limb_base); - output /= limb_base; - } - let mut output_limbs_f: Vec<_> = output_limbs - .into_iter() - .map(F::from_canonical_u64) - .collect(); - - v0.push(F::from_canonical_u64(m0)); - v0.push(F::from_canonical_u64(m1)); - v0.push(F::from_noncanonical_u64(a)); - v0.push(F::from_canonical_u64(output_low)); - v0.push(F::from_canonical_u64(output_high)); - v0.push(inverse); - v1.append(&mut output_limbs_f); - } - - v0.iter().chain(v1.iter()).map(|&x| x.into()).collect() - } - - #[test] - fn test_gate_constraint() { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type FF = >::FE; - const NUM_U32_ARITHMETIC_OPS: usize = 3; - - let mut rng = OsRng; - let multiplicands_0: Vec<_> = (0..NUM_U32_ARITHMETIC_OPS) - .map(|_| rng.gen::() as u64) - .collect(); - let multiplicands_1: Vec<_> = (0..NUM_U32_ARITHMETIC_OPS) - .map(|_| rng.gen::() as u64) - .collect(); - let addends: Vec<_> = (0..NUM_U32_ARITHMETIC_OPS) - .map(|_| rng.gen::() as u64) - .collect(); - - let gate = U32ArithmeticGate:: { - num_ops: NUM_U32_ARITHMETIC_OPS, - _phantom: PhantomData, - }; - - let vars = EvaluationVars { - local_constants: &[], - local_wires: &get_wires::( - multiplicands_0, - multiplicands_1, - addends, - ), - public_inputs_hash: &HashOut::rand(), - }; - - assert!( - gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), - "Gate constraints are not satisfied." - ); - } - - #[test] - fn test_canonicity() { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type FF = >::FE; - const NUM_U32_ARITHMETIC_OPS: usize = 3; - - let multiplicands_0 = vec![0; NUM_U32_ARITHMETIC_OPS]; - let multiplicands_1 = vec![0; NUM_U32_ARITHMETIC_OPS]; - // A non-canonical addend will produce a non-canonical output using - // get_wires. - let addends = vec![0xFFFFFFFF00000001; NUM_U32_ARITHMETIC_OPS]; - - let gate = U32ArithmeticGate:: { - num_ops: NUM_U32_ARITHMETIC_OPS, - _phantom: PhantomData, - }; - - let vars = EvaluationVars { - local_constants: &[], - local_wires: &get_wires::( - multiplicands_0, - multiplicands_1, - addends, - ), - public_inputs_hash: &HashOut::rand(), - }; - - assert!( - !gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), - "Non-canonical output should not pass constraints." - ); - } -} diff --git a/u32/src/gates/comparison.rs b/u32/src/gates/comparison.rs deleted file mode 100644 index d10f3b80..00000000 --- a/u32/src/gates/comparison.rs +++ /dev/null @@ -1,710 +0,0 @@ -use alloc::boxed::Box; -use alloc::string::String; -use alloc::vec::Vec; -use alloc::{format, vec}; -use core::marker::PhantomData; - -use plonky2::field::extension::Extendable; -use plonky2::field::packed::PackedField; -use plonky2::field::types::{Field, Field64}; -use plonky2::gates::gate::Gate; -use plonky2::gates::packed_util::PackedEvaluableBase; -use plonky2::gates::util::StridedConstraintConsumer; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::ext_target::ExtensionTarget; -use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; -use plonky2::iop::target::Target; -use plonky2::iop::wire::Wire; -use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_circuit}; -use plonky2::plonk::vars::{ - EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, - EvaluationVarsBasePacked, -}; -use plonky2::util::{bits_u64, ceil_div_usize}; - -/// A gate for checking that one value is less than or equal to another. -#[derive(Clone, Debug)] -pub struct ComparisonGate, const D: usize> { - pub(crate) num_bits: usize, - pub(crate) num_chunks: usize, - _phantom: PhantomData, -} - -impl, const D: usize> ComparisonGate { - pub fn new(num_bits: usize, num_chunks: usize) -> Self { - debug_assert!(num_bits < bits_u64(F::ORDER)); - Self { - num_bits, - num_chunks, - _phantom: PhantomData, - } - } - - pub fn chunk_bits(&self) -> usize { - ceil_div_usize(self.num_bits, self.num_chunks) - } - - pub fn wire_first_input(&self) -> usize { - 0 - } - - pub fn wire_second_input(&self) -> usize { - 1 - } - - pub fn wire_result_bool(&self) -> usize { - 2 - } - - pub fn wire_most_significant_diff(&self) -> usize { - 3 - } - - pub fn wire_first_chunk_val(&self, chunk: usize) -> usize { - debug_assert!(chunk < self.num_chunks); - 4 + chunk - } - - pub fn wire_second_chunk_val(&self, chunk: usize) -> usize { - debug_assert!(chunk < self.num_chunks); - 4 + self.num_chunks + chunk - } - - pub fn wire_equality_dummy(&self, chunk: usize) -> usize { - debug_assert!(chunk < self.num_chunks); - 4 + 2 * self.num_chunks + chunk - } - - pub fn wire_chunks_equal(&self, chunk: usize) -> usize { - debug_assert!(chunk < self.num_chunks); - 4 + 3 * self.num_chunks + chunk - } - - pub fn wire_intermediate_value(&self, chunk: usize) -> usize { - debug_assert!(chunk < self.num_chunks); - 4 + 4 * self.num_chunks + chunk - } - - /// The `bit_index`th bit of 2^n - 1 + most_significant_diff. - pub fn wire_most_significant_diff_bit(&self, bit_index: usize) -> usize { - 4 + 5 * self.num_chunks + bit_index - } -} - -impl, const D: usize> Gate for ComparisonGate { - fn id(&self) -> String { - format!("{self:?}") - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - let first_input = vars.local_wires[self.wire_first_input()]; - let second_input = vars.local_wires[self.wire_second_input()]; - - // Get chunks and assert that they match - let first_chunks: Vec = (0..self.num_chunks) - .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) - .collect(); - let second_chunks: Vec = (0..self.num_chunks) - .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) - .collect(); - - let first_chunks_combined = reduce_with_powers( - &first_chunks, - F::Extension::from_canonical_usize(1 << self.chunk_bits()), - ); - let second_chunks_combined = reduce_with_powers( - &second_chunks, - F::Extension::from_canonical_usize(1 << self.chunk_bits()), - ); - - constraints.push(first_chunks_combined - first_input); - constraints.push(second_chunks_combined - second_input); - - let chunk_size = 1 << self.chunk_bits(); - - let mut most_significant_diff_so_far = F::Extension::ZERO; - - for i in 0..self.num_chunks { - // Range-check the chunks to be less than `chunk_size`. - let first_product: F::Extension = (0..chunk_size) - .map(|x| first_chunks[i] - F::Extension::from_canonical_usize(x)) - .product(); - let second_product: F::Extension = (0..chunk_size) - .map(|x| second_chunks[i] - F::Extension::from_canonical_usize(x)) - .product(); - constraints.push(first_product); - constraints.push(second_product); - - let difference = second_chunks[i] - first_chunks[i]; - let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; - let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; - - // Two constraints to assert that `chunks_equal` is valid. - constraints.push(difference * equality_dummy - (F::Extension::ONE - chunks_equal)); - constraints.push(chunks_equal * difference); - - // Update `most_significant_diff_so_far`. - let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; - constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far); - most_significant_diff_so_far = - intermediate_value + (F::Extension::ONE - chunks_equal) * difference; - } - - let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; - constraints.push(most_significant_diff - most_significant_diff_so_far); - - let most_significant_diff_bits: Vec = (0..self.chunk_bits() + 1) - .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) - .collect(); - - // Range-check the bits. - for &bit in &most_significant_diff_bits { - constraints.push(bit * (F::Extension::ONE - bit)); - } - - let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::Extension::TWO); - let two_n = F::Extension::from_canonical_u64(1 << self.chunk_bits()); - constraints.push((two_n + most_significant_diff) - bits_combined); - - // Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1. - let result_bool = vars.local_wires[self.wire_result_bool()]; - constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]); - - constraints - } - - fn eval_unfiltered_base_one( - &self, - _vars: EvaluationVarsBase, - _yield_constr: StridedConstraintConsumer, - ) { - panic!("use eval_unfiltered_base_packed instead"); - } - - fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { - self.eval_unfiltered_base_batch_packed(vars_base) - } - - fn eval_unfiltered_circuit( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec> { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - let first_input = vars.local_wires[self.wire_first_input()]; - let second_input = vars.local_wires[self.wire_second_input()]; - - // Get chunks and assert that they match - let first_chunks: Vec> = (0..self.num_chunks) - .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) - .collect(); - let second_chunks: Vec> = (0..self.num_chunks) - .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) - .collect(); - - let chunk_base = builder.constant(F::from_canonical_usize(1 << self.chunk_bits())); - let first_chunks_combined = - reduce_with_powers_ext_circuit(builder, &first_chunks, chunk_base); - let second_chunks_combined = - reduce_with_powers_ext_circuit(builder, &second_chunks, chunk_base); - - constraints.push(builder.sub_extension(first_chunks_combined, first_input)); - constraints.push(builder.sub_extension(second_chunks_combined, second_input)); - - let chunk_size = 1 << self.chunk_bits(); - - let mut most_significant_diff_so_far = builder.zero_extension(); - - let one = builder.one_extension(); - // Find the chosen chunk. - for i in 0..self.num_chunks { - // Range-check the chunks to be less than `chunk_size`. - let mut first_product = one; - let mut second_product = one; - for x in 0..chunk_size { - let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x)); - let first_diff = builder.sub_extension(first_chunks[i], x_f); - let second_diff = builder.sub_extension(second_chunks[i], x_f); - first_product = builder.mul_extension(first_product, first_diff); - second_product = builder.mul_extension(second_product, second_diff); - } - constraints.push(first_product); - constraints.push(second_product); - - let difference = builder.sub_extension(second_chunks[i], first_chunks[i]); - let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; - let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; - - // Two constraints to assert that `chunks_equal` is valid. - let diff_times_equal = builder.mul_extension(difference, equality_dummy); - let not_equal = builder.sub_extension(one, chunks_equal); - constraints.push(builder.sub_extension(diff_times_equal, not_equal)); - constraints.push(builder.mul_extension(chunks_equal, difference)); - - // Update `most_significant_diff_so_far`. - let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; - let old_diff = builder.mul_extension(chunks_equal, most_significant_diff_so_far); - constraints.push(builder.sub_extension(intermediate_value, old_diff)); - - let not_equal = builder.sub_extension(one, chunks_equal); - let new_diff = builder.mul_extension(not_equal, difference); - most_significant_diff_so_far = builder.add_extension(intermediate_value, new_diff); - } - - let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; - constraints - .push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far)); - - let most_significant_diff_bits: Vec> = (0..self.chunk_bits() + 1) - .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) - .collect(); - - // Range-check the bits. - for &this_bit in &most_significant_diff_bits { - let inverse = builder.sub_extension(one, this_bit); - constraints.push(builder.mul_extension(this_bit, inverse)); - } - - let two = builder.two(); - let bits_combined = - reduce_with_powers_ext_circuit(builder, &most_significant_diff_bits, two); - let two_n = - builder.constant_extension(F::Extension::from_canonical_u64(1 << self.chunk_bits())); - let sum = builder.add_extension(two_n, most_significant_diff); - constraints.push(builder.sub_extension(sum, bits_combined)); - - // Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1. - let result_bool = vars.local_wires[self.wire_result_bool()]; - constraints.push( - builder.sub_extension(result_bool, most_significant_diff_bits[self.chunk_bits()]), - ); - - constraints - } - - fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { - let gen = ComparisonGenerator:: { - row, - gate: self.clone(), - }; - vec![Box::new(gen.adapter())] - } - - fn num_wires(&self) -> usize { - 4 + 5 * self.num_chunks + (self.chunk_bits() + 1) - } - - fn num_constants(&self) -> usize { - 0 - } - - fn degree(&self) -> usize { - 1 << self.chunk_bits() - } - - fn num_constraints(&self) -> usize { - 6 + 5 * self.num_chunks + self.chunk_bits() - } -} - -impl, const D: usize> PackedEvaluableBase - for ComparisonGate -{ - fn eval_unfiltered_base_packed>( - &self, - vars: EvaluationVarsBasePacked

, - mut yield_constr: StridedConstraintConsumer

, - ) { - let first_input = vars.local_wires[self.wire_first_input()]; - let second_input = vars.local_wires[self.wire_second_input()]; - - // Get chunks and assert that they match - let first_chunks: Vec<_> = (0..self.num_chunks) - .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) - .collect(); - let second_chunks: Vec<_> = (0..self.num_chunks) - .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) - .collect(); - - let first_chunks_combined = reduce_with_powers( - &first_chunks, - F::from_canonical_usize(1 << self.chunk_bits()), - ); - let second_chunks_combined = reduce_with_powers( - &second_chunks, - F::from_canonical_usize(1 << self.chunk_bits()), - ); - - yield_constr.one(first_chunks_combined - first_input); - yield_constr.one(second_chunks_combined - second_input); - - let chunk_size = 1 << self.chunk_bits(); - - let mut most_significant_diff_so_far = P::ZEROS; - - for i in 0..self.num_chunks { - // Range-check the chunks to be less than `chunk_size`. - let first_product: P = (0..chunk_size) - .map(|x| first_chunks[i] - F::from_canonical_usize(x)) - .product(); - let second_product: P = (0..chunk_size) - .map(|x| second_chunks[i] - F::from_canonical_usize(x)) - .product(); - yield_constr.one(first_product); - yield_constr.one(second_product); - - let difference = second_chunks[i] - first_chunks[i]; - let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; - let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; - - // Two constraints to assert that `chunks_equal` is valid. - yield_constr.one(difference * equality_dummy - (P::ONES - chunks_equal)); - yield_constr.one(chunks_equal * difference); - - // Update `most_significant_diff_so_far`. - let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; - yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far); - most_significant_diff_so_far = - intermediate_value + (P::ONES - chunks_equal) * difference; - } - - let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; - yield_constr.one(most_significant_diff - most_significant_diff_so_far); - - let most_significant_diff_bits: Vec<_> = (0..self.chunk_bits() + 1) - .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) - .collect(); - - // Range-check the bits. - for &bit in &most_significant_diff_bits { - yield_constr.one(bit * (P::ONES - bit)); - } - - let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO); - let two_n = F::from_canonical_u64(1 << self.chunk_bits()); - yield_constr.one((most_significant_diff + two_n) - bits_combined); - - // Iff first <= second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. - let result_bool = vars.local_wires[self.wire_result_bool()]; - yield_constr.one(result_bool - most_significant_diff_bits[self.chunk_bits()]); - } -} - -#[derive(Debug)] -struct ComparisonGenerator, const D: usize> { - row: usize, - gate: ComparisonGate, -} - -impl, const D: usize> SimpleGenerator - for ComparisonGenerator -{ - fn dependencies(&self) -> Vec { - let local_target = |column| Target::wire(self.row, column); - - vec![ - local_target(self.gate.wire_first_input()), - local_target(self.gate.wire_second_input()), - ] - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let local_wire = |column| Wire { - row: self.row, - column, - }; - - let get_local_wire = |column| witness.get_wire(local_wire(column)); - - let first_input = get_local_wire(self.gate.wire_first_input()); - let second_input = get_local_wire(self.gate.wire_second_input()); - - let first_input_u64 = first_input.to_canonical_u64(); - let second_input_u64 = second_input.to_canonical_u64(); - - let result = F::from_canonical_usize((first_input_u64 <= second_input_u64) as usize); - - let chunk_size = 1 << self.gate.chunk_bits(); - let first_input_chunks: Vec = (0..self.gate.num_chunks) - .scan(first_input_u64, |acc, _| { - let tmp = *acc % chunk_size; - *acc /= chunk_size; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); - let second_input_chunks: Vec = (0..self.gate.num_chunks) - .scan(second_input_u64, |acc, _| { - let tmp = *acc % chunk_size; - *acc /= chunk_size; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); - - let chunks_equal: Vec = (0..self.gate.num_chunks) - .map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i])) - .collect(); - let equality_dummies: Vec = first_input_chunks - .iter() - .zip(second_input_chunks.iter()) - .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) }) - .collect(); - - let mut most_significant_diff_so_far = F::ZERO; - let mut intermediate_values = Vec::new(); - for i in 0..self.gate.num_chunks { - if first_input_chunks[i] != second_input_chunks[i] { - most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i]; - intermediate_values.push(F::ZERO); - } else { - intermediate_values.push(most_significant_diff_so_far); - } - } - let most_significant_diff = most_significant_diff_so_far; - - let two_n = F::from_canonical_usize(1 << self.gate.chunk_bits()); - let two_n_plus_msd = (two_n + most_significant_diff).to_canonical_u64(); - - let msd_bits_u64: Vec = (0..self.gate.chunk_bits() + 1) - .scan(two_n_plus_msd, |acc, _| { - let tmp = *acc % 2; - *acc /= 2; - Some(tmp) - }) - .collect(); - let msd_bits: Vec = msd_bits_u64 - .iter() - .map(|x| F::from_canonical_u64(*x)) - .collect(); - - out_buffer.set_wire(local_wire(self.gate.wire_result_bool()), result); - out_buffer.set_wire( - local_wire(self.gate.wire_most_significant_diff()), - most_significant_diff, - ); - for i in 0..self.gate.num_chunks { - out_buffer.set_wire( - local_wire(self.gate.wire_first_chunk_val(i)), - first_input_chunks[i], - ); - out_buffer.set_wire( - local_wire(self.gate.wire_second_chunk_val(i)), - second_input_chunks[i], - ); - out_buffer.set_wire( - local_wire(self.gate.wire_equality_dummy(i)), - equality_dummies[i], - ); - out_buffer.set_wire(local_wire(self.gate.wire_chunks_equal(i)), chunks_equal[i]); - out_buffer.set_wire( - local_wire(self.gate.wire_intermediate_value(i)), - intermediate_values[i], - ); - } - for i in 0..self.gate.chunk_bits() + 1 { - out_buffer.set_wire( - local_wire(self.gate.wire_most_significant_diff_bit(i)), - msd_bits[i], - ); - } - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::field::goldilocks_field::GoldilocksField; - use plonky2::field::types::{PrimeField64, Sample}; - use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; - use plonky2::hash::hash_types::HashOut; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use rand::rngs::OsRng; - use rand::Rng; - - use super::*; - - #[test] - fn wire_indices() { - type CG = ComparisonGate; - let num_bits = 40; - let num_chunks = 5; - - let gate = CG { - num_bits, - num_chunks, - _phantom: PhantomData, - }; - - assert_eq!(gate.wire_first_input(), 0); - assert_eq!(gate.wire_second_input(), 1); - assert_eq!(gate.wire_result_bool(), 2); - assert_eq!(gate.wire_most_significant_diff(), 3); - assert_eq!(gate.wire_first_chunk_val(0), 4); - assert_eq!(gate.wire_first_chunk_val(4), 8); - assert_eq!(gate.wire_second_chunk_val(0), 9); - assert_eq!(gate.wire_second_chunk_val(4), 13); - assert_eq!(gate.wire_equality_dummy(0), 14); - assert_eq!(gate.wire_equality_dummy(4), 18); - assert_eq!(gate.wire_chunks_equal(0), 19); - assert_eq!(gate.wire_chunks_equal(4), 23); - assert_eq!(gate.wire_intermediate_value(0), 24); - assert_eq!(gate.wire_intermediate_value(4), 28); - assert_eq!(gate.wire_most_significant_diff_bit(0), 29); - assert_eq!(gate.wire_most_significant_diff_bit(8), 37); - } - - #[test] - fn low_degree() { - let num_bits = 40; - let num_chunks = 5; - - test_low_degree::(ComparisonGate::<_, 4>::new(num_bits, num_chunks)) - } - - #[test] - fn eval_fns() -> Result<()> { - let num_bits = 40; - let num_chunks = 5; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - test_eval_fns::(ComparisonGate::<_, 2>::new(num_bits, num_chunks)) - } - - #[test] - fn test_gate_constraint() { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type FF = >::FE; - - let num_bits = 40; - let num_chunks = 5; - let chunk_bits = num_bits / num_chunks; - - // Returns the local wires for a comparison gate given the two inputs. - let get_wires = |first_input: F, second_input: F| -> Vec { - let mut v = Vec::new(); - - let first_input_u64 = first_input.to_canonical_u64(); - let second_input_u64 = second_input.to_canonical_u64(); - - let result_bool = F::from_bool(first_input_u64 <= second_input_u64); - - let chunk_size = 1 << chunk_bits; - let mut first_input_chunks: Vec = (0..num_chunks) - .scan(first_input_u64, |acc, _| { - let tmp = *acc % chunk_size; - *acc /= chunk_size; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); - let mut second_input_chunks: Vec = (0..num_chunks) - .scan(second_input_u64, |acc, _| { - let tmp = *acc % chunk_size; - *acc /= chunk_size; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); - - let mut chunks_equal: Vec = (0..num_chunks) - .map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i])) - .collect(); - let mut equality_dummies: Vec = first_input_chunks - .iter() - .zip(second_input_chunks.iter()) - .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) }) - .collect(); - - let mut most_significant_diff_so_far = F::ZERO; - let mut intermediate_values = Vec::new(); - for i in 0..num_chunks { - if first_input_chunks[i] != second_input_chunks[i] { - most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i]; - intermediate_values.push(F::ZERO); - } else { - intermediate_values.push(most_significant_diff_so_far); - } - } - let most_significant_diff = most_significant_diff_so_far; - - let two_n_plus_msd = - (1 << chunk_bits) as u64 + most_significant_diff.to_canonical_u64(); - let mut msd_bits: Vec = (0..chunk_bits + 1) - .scan(two_n_plus_msd, |acc, _| { - let tmp = *acc % 2; - *acc /= 2; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); - - v.push(first_input); - v.push(second_input); - v.push(result_bool); - v.push(most_significant_diff); - v.append(&mut first_input_chunks); - v.append(&mut second_input_chunks); - v.append(&mut equality_dummies); - v.append(&mut chunks_equal); - v.append(&mut intermediate_values); - v.append(&mut msd_bits); - - v.iter().map(|&x| x.into()).collect() - }; - - let mut rng = OsRng; - let max: u64 = 1 << (num_bits - 1); - let first_input_u64 = rng.gen_range(0..max); - let second_input_u64 = { - let mut val = rng.gen_range(0..max); - while val < first_input_u64 { - val = rng.gen_range(0..max); - } - val - }; - - let first_input = F::from_canonical_u64(first_input_u64); - let second_input = F::from_canonical_u64(second_input_u64); - - let less_than_gate = ComparisonGate:: { - num_bits, - num_chunks, - _phantom: PhantomData, - }; - let less_than_vars = EvaluationVars { - local_constants: &[], - local_wires: &get_wires(first_input, second_input), - public_inputs_hash: &HashOut::rand(), - }; - assert!( - less_than_gate - .eval_unfiltered(less_than_vars) - .iter() - .all(|x| x.is_zero()), - "Gate constraints are not satisfied." - ); - - let equal_gate = ComparisonGate:: { - num_bits, - num_chunks, - _phantom: PhantomData, - }; - let equal_vars = EvaluationVars { - local_constants: &[], - local_wires: &get_wires(first_input, first_input), - public_inputs_hash: &HashOut::rand(), - }; - assert!( - equal_gate - .eval_unfiltered(equal_vars) - .iter() - .all(|x| x.is_zero()), - "Gate constraints are not satisfied." - ); - } -} diff --git a/u32/src/gates/mod.rs b/u32/src/gates/mod.rs deleted file mode 100644 index 1880b163..00000000 --- a/u32/src/gates/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub mod add_many_u32; -pub mod arithmetic_u32; -pub mod comparison; -pub mod range_check_u32; -pub mod subtraction_u32; diff --git a/u32/src/gates/range_check_u32.rs b/u32/src/gates/range_check_u32.rs deleted file mode 100644 index 55faa6ca..00000000 --- a/u32/src/gates/range_check_u32.rs +++ /dev/null @@ -1,307 +0,0 @@ -use alloc::boxed::Box; -use alloc::string::String; -use alloc::vec::Vec; -use alloc::{format, vec}; -use core::marker::PhantomData; - -use plonky2::field::extension::Extendable; -use plonky2::field::types::Field; -use plonky2::gates::gate::Gate; -use plonky2::gates::util::StridedConstraintConsumer; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::ext_target::ExtensionTarget; -use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; -use plonky2::iop::target::Target; -use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_circuit}; -use plonky2::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -use plonky2::util::ceil_div_usize; - -/// A gate which can decompose a number into base B little-endian limbs. -#[derive(Copy, Clone, Debug)] -pub struct U32RangeCheckGate, const D: usize> { - pub num_input_limbs: usize, - _phantom: PhantomData, -} - -impl, const D: usize> U32RangeCheckGate { - pub fn new(num_input_limbs: usize) -> Self { - Self { - num_input_limbs, - _phantom: PhantomData, - } - } - - pub const AUX_LIMB_BITS: usize = 2; - pub const BASE: usize = 1 << Self::AUX_LIMB_BITS; - - fn aux_limbs_per_input_limb(&self) -> usize { - ceil_div_usize(32, Self::AUX_LIMB_BITS) - } - pub fn wire_ith_input_limb(&self, i: usize) -> usize { - debug_assert!(i < self.num_input_limbs); - i - } - pub fn wire_ith_input_limb_jth_aux_limb(&self, i: usize, j: usize) -> usize { - debug_assert!(i < self.num_input_limbs); - debug_assert!(j < self.aux_limbs_per_input_limb()); - self.num_input_limbs + self.aux_limbs_per_input_limb() * i + j - } -} - -impl, const D: usize> Gate for U32RangeCheckGate { - fn id(&self) -> String { - format!("{self:?}") - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - let base = F::Extension::from_canonical_usize(Self::BASE); - for i in 0..self.num_input_limbs { - let input_limb = vars.local_wires[self.wire_ith_input_limb(i)]; - let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb()) - .map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)]) - .collect(); - let computed_sum = reduce_with_powers(&aux_limbs, base); - - constraints.push(computed_sum - input_limb); - for aux_limb in aux_limbs { - constraints.push( - (0..Self::BASE) - .map(|i| aux_limb - F::Extension::from_canonical_usize(i)) - .product(), - ); - } - } - - constraints - } - - fn eval_unfiltered_base_one( - &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, - ) { - let base = F::from_canonical_usize(Self::BASE); - for i in 0..self.num_input_limbs { - let input_limb = vars.local_wires[self.wire_ith_input_limb(i)]; - let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb()) - .map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)]) - .collect(); - let computed_sum = reduce_with_powers(&aux_limbs, base); - - yield_constr.one(computed_sum - input_limb); - for aux_limb in aux_limbs { - yield_constr.one( - (0..Self::BASE) - .map(|i| aux_limb - F::from_canonical_usize(i)) - .product(), - ); - } - } - } - - fn eval_unfiltered_circuit( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec> { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - let base = builder.constant(F::from_canonical_usize(Self::BASE)); - for i in 0..self.num_input_limbs { - let input_limb = vars.local_wires[self.wire_ith_input_limb(i)]; - let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb()) - .map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)]) - .collect(); - let computed_sum = reduce_with_powers_ext_circuit(builder, &aux_limbs, base); - - constraints.push(builder.sub_extension(computed_sum, input_limb)); - for aux_limb in aux_limbs { - constraints.push({ - let mut acc = builder.one_extension(); - (0..Self::BASE).for_each(|i| { - // We update our accumulator as: - // acc' = acc (x - i) - // = acc x + (-i) acc - // Since -i is constant, we can do this in one arithmetic_extension call. - let neg_i = -F::from_canonical_usize(i); - acc = builder.arithmetic_extension(F::ONE, neg_i, acc, aux_limb, acc) - }); - acc - }); - } - } - - constraints - } - - fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { - let gen = U32RangeCheckGenerator { gate: *self, row }; - vec![Box::new(gen.adapter())] - } - - fn num_wires(&self) -> usize { - self.num_input_limbs * (1 + self.aux_limbs_per_input_limb()) - } - - fn num_constants(&self) -> usize { - 0 - } - - // Bounded by the range-check (x-0)*(x-1)*...*(x-BASE+1). - fn degree(&self) -> usize { - Self::BASE - } - - // 1 for checking the each sum of aux limbs, plus a range check for each aux limb. - fn num_constraints(&self) -> usize { - self.num_input_limbs * (1 + self.aux_limbs_per_input_limb()) - } -} - -#[derive(Debug)] -pub struct U32RangeCheckGenerator, const D: usize> { - gate: U32RangeCheckGate, - row: usize, -} - -impl, const D: usize> SimpleGenerator - for U32RangeCheckGenerator -{ - fn dependencies(&self) -> Vec { - let num_input_limbs = self.gate.num_input_limbs; - (0..num_input_limbs) - .map(|i| Target::wire(self.row, self.gate.wire_ith_input_limb(i))) - .collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let num_input_limbs = self.gate.num_input_limbs; - for i in 0..num_input_limbs { - let sum_value = witness - .get_target(Target::wire(self.row, self.gate.wire_ith_input_limb(i))) - .to_canonical_u64() as u32; - - let base = U32RangeCheckGate::::BASE as u32; - let limbs = (0..self.gate.aux_limbs_per_input_limb()) - .map(|j| Target::wire(self.row, self.gate.wire_ith_input_limb_jth_aux_limb(i, j))); - let limbs_value = (0..self.gate.aux_limbs_per_input_limb()) - .scan(sum_value, |acc, _| { - let tmp = *acc % base; - *acc /= base; - Some(F::from_canonical_u32(tmp)) - }) - .collect::>(); - - for (b, b_value) in limbs.zip(limbs_value) { - out_buffer.set_target(b, b_value); - } - } - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use itertools::unfold; - use plonky2::field::extension::quartic::QuarticExtension; - use plonky2::field::goldilocks_field::GoldilocksField; - use plonky2::field::types::{Field, Sample}; - use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; - use plonky2::hash::hash_types::HashOut; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use rand::rngs::OsRng; - use rand::Rng; - - use super::*; - - #[test] - fn low_degree() { - test_low_degree::(U32RangeCheckGate::new(8)) - } - - #[test] - fn eval_fns() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - test_eval_fns::(U32RangeCheckGate::new(8)) - } - - fn test_gate_constraint(input_limbs: Vec) { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; - const AUX_LIMB_BITS: usize = 2; - const BASE: usize = 1 << AUX_LIMB_BITS; - const AUX_LIMBS_PER_INPUT_LIMB: usize = ceil_div_usize(32, AUX_LIMB_BITS); - - fn get_wires(input_limbs: Vec) -> Vec { - let num_input_limbs = input_limbs.len(); - let mut v = Vec::new(); - - for i in 0..num_input_limbs { - let input_limb = input_limbs[i]; - - let split_to_limbs = |mut val, num| { - unfold((), move |_| { - let ret = val % (BASE as u64); - val /= BASE as u64; - Some(ret) - }) - .take(num) - .map(F::from_canonical_u64) - }; - - let mut aux_limbs: Vec<_> = - split_to_limbs(input_limb, AUX_LIMBS_PER_INPUT_LIMB).collect(); - - v.append(&mut aux_limbs); - } - - input_limbs - .iter() - .cloned() - .map(F::from_canonical_u64) - .chain(v.iter().cloned()) - .map(|x| x.into()) - .collect() - } - - let gate = U32RangeCheckGate:: { - num_input_limbs: 8, - _phantom: PhantomData, - }; - - let vars = EvaluationVars { - local_constants: &[], - local_wires: &get_wires(input_limbs), - public_inputs_hash: &HashOut::rand(), - }; - - assert!( - gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), - "Gate constraints are not satisfied." - ); - } - - #[test] - fn test_gate_constraint_good() { - let mut rng = OsRng; - let input_limbs: Vec<_> = (0..8).map(|_| rng.gen::() as u64).collect(); - - test_gate_constraint(input_limbs); - } - - #[test] - #[should_panic] - fn test_gate_constraint_bad() { - let mut rng = OsRng; - let input_limbs: Vec<_> = (0..8).map(|_| rng.gen()).collect(); - - test_gate_constraint(input_limbs); - } -} diff --git a/u32/src/gates/subtraction_u32.rs b/u32/src/gates/subtraction_u32.rs deleted file mode 100644 index 01f55e09..00000000 --- a/u32/src/gates/subtraction_u32.rs +++ /dev/null @@ -1,445 +0,0 @@ -use alloc::boxed::Box; -use alloc::string::String; -use alloc::vec::Vec; -use alloc::{format, vec}; -use core::marker::PhantomData; - -use plonky2::field::extension::Extendable; -use plonky2::field::packed::PackedField; -use plonky2::field::types::Field; -use plonky2::gates::gate::Gate; -use plonky2::gates::packed_util::PackedEvaluableBase; -use plonky2::gates::util::StridedConstraintConsumer; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::ext_target::ExtensionTarget; -use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; -use plonky2::iop::target::Target; -use plonky2::iop::wire::Wire; -use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::CircuitConfig; -use plonky2::plonk::vars::{ - EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, - EvaluationVarsBasePacked, -}; - -/// A gate to perform a subtraction on 32-bit limbs: given `x`, `y`, and `borrow`, it returns -/// the result `x - y - borrow` and, if this underflows, a new `borrow`. Inputs are not range-checked. -#[derive(Copy, Clone, Debug)] -pub struct U32SubtractionGate, const D: usize> { - pub num_ops: usize, - _phantom: PhantomData, -} - -impl, const D: usize> U32SubtractionGate { - pub fn new_from_config(config: &CircuitConfig) -> Self { - Self { - num_ops: Self::num_ops(config), - _phantom: PhantomData, - } - } - - pub(crate) fn num_ops(config: &CircuitConfig) -> usize { - let wires_per_op = 5 + Self::num_limbs(); - let routed_wires_per_op = 5; - (config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op) - } - - pub fn wire_ith_input_x(&self, i: usize) -> usize { - debug_assert!(i < self.num_ops); - 5 * i - } - pub fn wire_ith_input_y(&self, i: usize) -> usize { - debug_assert!(i < self.num_ops); - 5 * i + 1 - } - pub fn wire_ith_input_borrow(&self, i: usize) -> usize { - debug_assert!(i < self.num_ops); - 5 * i + 2 - } - - pub fn wire_ith_output_result(&self, i: usize) -> usize { - debug_assert!(i < self.num_ops); - 5 * i + 3 - } - pub fn wire_ith_output_borrow(&self, i: usize) -> usize { - debug_assert!(i < self.num_ops); - 5 * i + 4 - } - - pub fn limb_bits() -> usize { - 2 - } - // We have limbs for the 32 bits of `output_result`. - pub fn num_limbs() -> usize { - 32 / Self::limb_bits() - } - - pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize { - debug_assert!(i < self.num_ops); - debug_assert!(j < Self::num_limbs()); - 5 * self.num_ops + Self::num_limbs() * i + j - } -} - -impl, const D: usize> Gate for U32SubtractionGate { - fn id(&self) -> String { - format!("{self:?}") - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..self.num_ops { - let input_x = vars.local_wires[self.wire_ith_input_x(i)]; - let input_y = vars.local_wires[self.wire_ith_input_y(i)]; - let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; - - let result_initial = input_x - input_y - input_borrow; - let base = F::Extension::from_canonical_u64(1 << 32u64); - - let output_result = vars.local_wires[self.wire_ith_output_result(i)]; - let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; - - constraints.push(output_result - (result_initial + base * output_borrow)); - - // Range-check output_result to be at most 32 bits. - let mut combined_limbs = F::Extension::ZERO; - let limb_base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); - for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; - let max_limb = 1 << Self::limb_bits(); - let product = (0..max_limb) - .map(|x| this_limb - F::Extension::from_canonical_usize(x)) - .product(); - constraints.push(product); - - combined_limbs = limb_base * combined_limbs + this_limb; - } - constraints.push(combined_limbs - output_result); - - // Range-check output_borrow to be one bit. - constraints.push(output_borrow * (F::Extension::ONE - output_borrow)); - } - - constraints - } - - fn eval_unfiltered_base_one( - &self, - _vars: EvaluationVarsBase, - _yield_constr: StridedConstraintConsumer, - ) { - panic!("use eval_unfiltered_base_packed instead"); - } - - fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { - self.eval_unfiltered_base_batch_packed(vars_base) - } - - fn eval_unfiltered_circuit( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec> { - let mut constraints = Vec::with_capacity(self.num_constraints()); - for i in 0..self.num_ops { - let input_x = vars.local_wires[self.wire_ith_input_x(i)]; - let input_y = vars.local_wires[self.wire_ith_input_y(i)]; - let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; - - let diff = builder.sub_extension(input_x, input_y); - let result_initial = builder.sub_extension(diff, input_borrow); - let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << 32u64)); - - let output_result = vars.local_wires[self.wire_ith_output_result(i)]; - let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; - - let computed_output = builder.mul_add_extension(base, output_borrow, result_initial); - constraints.push(builder.sub_extension(output_result, computed_output)); - - // Range-check output_result to be at most 32 bits. - let mut combined_limbs = builder.zero_extension(); - let limb_base = builder - .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); - for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; - let max_limb = 1 << Self::limb_bits(); - let mut product = builder.one_extension(); - for x in 0..max_limb { - let x_target = - builder.constant_extension(F::Extension::from_canonical_usize(x)); - let diff = builder.sub_extension(this_limb, x_target); - product = builder.mul_extension(product, diff); - } - constraints.push(product); - - combined_limbs = builder.mul_add_extension(limb_base, combined_limbs, this_limb); - } - constraints.push(builder.sub_extension(combined_limbs, output_result)); - - // Range-check output_borrow to be one bit. - let one = builder.one_extension(); - let not_borrow = builder.sub_extension(one, output_borrow); - constraints.push(builder.mul_extension(output_borrow, not_borrow)); - } - - constraints - } - - fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { - (0..self.num_ops) - .map(|i| { - let g: Box> = Box::new( - U32SubtractionGenerator { - gate: *self, - row, - i, - _phantom: PhantomData, - } - .adapter(), - ); - g - }) - .collect() - } - - fn num_wires(&self) -> usize { - self.num_ops * (5 + Self::num_limbs()) - } - - fn num_constants(&self) -> usize { - 0 - } - - fn degree(&self) -> usize { - 1 << Self::limb_bits() - } - - fn num_constraints(&self) -> usize { - self.num_ops * (3 + Self::num_limbs()) - } -} - -impl, const D: usize> PackedEvaluableBase - for U32SubtractionGate -{ - fn eval_unfiltered_base_packed>( - &self, - vars: EvaluationVarsBasePacked

, - mut yield_constr: StridedConstraintConsumer

, - ) { - for i in 0..self.num_ops { - let input_x = vars.local_wires[self.wire_ith_input_x(i)]; - let input_y = vars.local_wires[self.wire_ith_input_y(i)]; - let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; - - let result_initial = input_x - input_y - input_borrow; - let base = F::from_canonical_u64(1 << 32u64); - - let output_result = vars.local_wires[self.wire_ith_output_result(i)]; - let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; - - yield_constr.one(output_result - (result_initial + output_borrow * base)); - - // Range-check output_result to be at most 32 bits. - let mut combined_limbs = P::ZEROS; - let limb_base = F::from_canonical_u64(1u64 << Self::limb_bits()); - for j in (0..Self::num_limbs()).rev() { - let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; - let max_limb = 1 << Self::limb_bits(); - let product = (0..max_limb) - .map(|x| this_limb - F::from_canonical_usize(x)) - .product(); - yield_constr.one(product); - - combined_limbs = combined_limbs * limb_base + this_limb; - } - yield_constr.one(combined_limbs - output_result); - - // Range-check output_borrow to be one bit. - yield_constr.one(output_borrow * (P::ONES - output_borrow)); - } - } -} - -#[derive(Clone, Debug)] -struct U32SubtractionGenerator, const D: usize> { - gate: U32SubtractionGate, - row: usize, - i: usize, - _phantom: PhantomData, -} - -impl, const D: usize> SimpleGenerator - for U32SubtractionGenerator -{ - fn dependencies(&self) -> Vec { - let local_target = |column| Target::wire(self.row, column); - - vec![ - local_target(self.gate.wire_ith_input_x(self.i)), - local_target(self.gate.wire_ith_input_y(self.i)), - local_target(self.gate.wire_ith_input_borrow(self.i)), - ] - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let local_wire = |column| Wire { - row: self.row, - column, - }; - - let get_local_wire = |column| witness.get_wire(local_wire(column)); - - let input_x = get_local_wire(self.gate.wire_ith_input_x(self.i)); - let input_y = get_local_wire(self.gate.wire_ith_input_y(self.i)); - let input_borrow = get_local_wire(self.gate.wire_ith_input_borrow(self.i)); - - let result_initial = input_x - input_y - input_borrow; - let result_initial_u64 = result_initial.to_canonical_u64(); - let output_borrow = if result_initial_u64 > 1 << 32u64 { - F::ONE - } else { - F::ZERO - }; - - let base = F::from_canonical_u64(1 << 32u64); - let output_result = result_initial + base * output_borrow; - - let output_result_wire = local_wire(self.gate.wire_ith_output_result(self.i)); - let output_borrow_wire = local_wire(self.gate.wire_ith_output_borrow(self.i)); - - out_buffer.set_wire(output_result_wire, output_result); - out_buffer.set_wire(output_borrow_wire, output_borrow); - - let output_result_u64 = output_result.to_canonical_u64(); - - let num_limbs = U32SubtractionGate::::num_limbs(); - let limb_base = 1 << U32SubtractionGate::::limb_bits(); - let output_limbs: Vec<_> = (0..num_limbs) - .scan(output_result_u64, |acc, _| { - let tmp = *acc % limb_base; - *acc /= limb_base; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); - - for j in 0..num_limbs { - let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j)); - out_buffer.set_wire(wire, output_limbs[j]); - } - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::field::extension::quartic::QuarticExtension; - use plonky2::field::goldilocks_field::GoldilocksField; - use plonky2::field::types::{PrimeField64, Sample}; - use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; - use plonky2::hash::hash_types::HashOut; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use rand::rngs::OsRng; - use rand::Rng; - - use super::*; - - #[test] - fn low_degree() { - test_low_degree::(U32SubtractionGate:: { - num_ops: 3, - _phantom: PhantomData, - }) - } - - #[test] - fn eval_fns() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - test_eval_fns::(U32SubtractionGate:: { - num_ops: 3, - _phantom: PhantomData, - }) - } - - #[test] - fn test_gate_constraint() { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; - const NUM_U32_SUBTRACTION_OPS: usize = 3; - - fn get_wires(inputs_x: Vec, inputs_y: Vec, borrows: Vec) -> Vec { - let mut v0 = Vec::new(); - let mut v1 = Vec::new(); - - let limb_bits = U32SubtractionGate::::limb_bits(); - let num_limbs = U32SubtractionGate::::num_limbs(); - let limb_base = 1 << limb_bits; - for c in 0..NUM_U32_SUBTRACTION_OPS { - let input_x = F::from_canonical_u64(inputs_x[c]); - let input_y = F::from_canonical_u64(inputs_y[c]); - let input_borrow = F::from_canonical_u64(borrows[c]); - - let result_initial = input_x - input_y - input_borrow; - let result_initial_u64 = result_initial.to_canonical_u64(); - let output_borrow = if result_initial_u64 > 1 << 32u64 { - F::ONE - } else { - F::ZERO - }; - - let base = F::from_canonical_u64(1 << 32u64); - let output_result = result_initial + base * output_borrow; - - let output_result_u64 = output_result.to_canonical_u64(); - - let mut output_limbs: Vec<_> = (0..num_limbs) - .scan(output_result_u64, |acc, _| { - let tmp = *acc % limb_base; - *acc /= limb_base; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); - - v0.push(input_x); - v0.push(input_y); - v0.push(input_borrow); - v0.push(output_result); - v0.push(output_borrow); - v1.append(&mut output_limbs); - } - - v0.iter().chain(v1.iter()).map(|&x| x.into()).collect() - } - - let mut rng = OsRng; - let inputs_x = (0..NUM_U32_SUBTRACTION_OPS) - .map(|_| rng.gen::() as u64) - .collect(); - let inputs_y = (0..NUM_U32_SUBTRACTION_OPS) - .map(|_| rng.gen::() as u64) - .collect(); - let borrows = (0..NUM_U32_SUBTRACTION_OPS) - .map(|_| (rng.gen::() % 2) as u64) - .collect(); - - let gate = U32SubtractionGate:: { - num_ops: NUM_U32_SUBTRACTION_OPS, - _phantom: PhantomData, - }; - - let vars = EvaluationVars { - local_constants: &[], - local_wires: &get_wires(inputs_x, inputs_y, borrows), - public_inputs_hash: &HashOut::rand(), - }; - - assert!( - gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), - "Gate constraints are not satisfied." - ); - } -} diff --git a/u32/src/lib.rs b/u32/src/lib.rs deleted file mode 100644 index 2d8d07f3..00000000 --- a/u32/src/lib.rs +++ /dev/null @@ -1,8 +0,0 @@ -#![allow(clippy::needless_range_loop)] -#![no_std] - -extern crate alloc; - -pub mod gadgets; -pub mod gates; -pub mod witness; diff --git a/u32/src/witness.rs b/u32/src/witness.rs deleted file mode 100644 index cf308d2a..00000000 --- a/u32/src/witness.rs +++ /dev/null @@ -1,33 +0,0 @@ -use plonky2::field::types::{Field, PrimeField64}; -use plonky2::iop::generator::GeneratedValues; -use plonky2::iop::witness::{Witness, WitnessWrite}; - -use crate::gadgets::arithmetic_u32::U32Target; - -pub trait WitnessU32: Witness { - fn set_u32_target(&mut self, target: U32Target, value: u32); - fn get_u32_target(&self, target: U32Target) -> (u32, u32); -} - -impl, F: PrimeField64> WitnessU32 for T { - fn set_u32_target(&mut self, target: U32Target, value: u32) { - self.set_target(target.0, F::from_canonical_u32(value)); - } - - fn get_u32_target(&self, target: U32Target) -> (u32, u32) { - let x_u64 = self.get_target(target.0).to_canonical_u64(); - let low = x_u64 as u32; - let high = (x_u64 >> 32) as u32; - (low, high) - } -} - -pub trait GeneratedValuesU32 { - fn set_u32_target(&mut self, target: U32Target, value: u32); -} - -impl GeneratedValuesU32 for GeneratedValues { - fn set_u32_target(&mut self, target: U32Target, value: u32) { - self.set_target(target.0, F::from_canonical_u32(value)) - } -} diff --git a/waksman/Cargo.toml b/waksman/Cargo.toml deleted file mode 100644 index 98d76a52..00000000 --- a/waksman/Cargo.toml +++ /dev/null @@ -1,15 +0,0 @@ -[package] -name = "plonky2_waksman" -description = "A circuit implementation AS-Waksman networks, useful for checking permutations and sorting" -version = "0.1.0" -edition = "2021" - -[dependencies] -anyhow = "1.0.40" -array_tool = "1.0.3" -bimap = "0.6.1" -itertools = "0.10.0" -"plonky2" = { version = "0.1.0" } -"plonky2_field" = { version = "0.1.0" } -"plonky2_util" = { version = "0.1.0" } -rand = "0.8.4" diff --git a/waksman/LICENSE-APACHE b/waksman/LICENSE-APACHE deleted file mode 100644 index 1e5006dc..00000000 --- a/waksman/LICENSE-APACHE +++ /dev/null @@ -1,202 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - diff --git a/waksman/LICENSE-MIT b/waksman/LICENSE-MIT deleted file mode 100644 index 86d690b2..00000000 --- a/waksman/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2022 The Plonky2 Authors - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. diff --git a/waksman/README.md b/waksman/README.md deleted file mode 100644 index bb4e2d8a..00000000 --- a/waksman/README.md +++ /dev/null @@ -1,13 +0,0 @@ -## License - -Licensed under either of - -* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) -* MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) - -at your option. - - -### Contribution - -Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions. diff --git a/waksman/src/bimap.rs b/waksman/src/bimap.rs deleted file mode 100644 index 28359d9f..00000000 --- a/waksman/src/bimap.rs +++ /dev/null @@ -1,75 +0,0 @@ -use std::collections::HashMap; -use std::hash::Hash; - -use bimap::BiMap; - -/// Given two lists which are permutations of one another, creates a BiMap which maps an index in -/// one list to an index in the other list with the same associated value. -/// -/// If the lists contain duplicates, then multiple permutations with this property exist, and an -/// arbitrary one of them will be returned. -pub fn bimap_from_lists(a: Vec, b: Vec) -> BiMap { - assert_eq!(a.len(), b.len(), "Vectors differ in length"); - - let mut b_values_to_indices = HashMap::new(); - for (i, value) in b.iter().enumerate() { - b_values_to_indices - .entry(value) - .or_insert_with(Vec::new) - .push(i); - } - - let mut bimap = BiMap::new(); - for (i, value) in a.iter().enumerate() { - if let Some(j) = b_values_to_indices.get_mut(&value).and_then(Vec::pop) { - bimap.insert(i, j); - } else { - panic!("Value in first list not found in second list"); - } - } - - bimap -} - -#[cfg(test)] -mod tests { - use crate::bimap::bimap_from_lists; - - #[test] - fn empty_lists() { - let empty: Vec = Vec::new(); - let bimap = bimap_from_lists(empty.clone(), empty); - assert!(bimap.is_empty()); - } - - #[test] - fn without_duplicates() { - let bimap = bimap_from_lists(vec!['a', 'b', 'c'], vec!['b', 'c', 'a']); - assert_eq!(bimap.get_by_left(&0), Some(&2)); - assert_eq!(bimap.get_by_left(&1), Some(&0)); - assert_eq!(bimap.get_by_left(&2), Some(&1)); - } - - #[test] - fn with_duplicates() { - let first = vec!['a', 'a', 'b']; - let second = vec!['a', 'b', 'a']; - let bimap = bimap_from_lists(first.clone(), second.clone()); - for i in 0..3 { - let j = *bimap.get_by_left(&i).unwrap(); - assert_eq!(first[i], second[j]); - } - } - - #[test] - #[should_panic] - fn lengths_differ() { - bimap_from_lists(vec!['a', 'a', 'b'], vec!['a', 'b']); - } - - #[test] - #[should_panic] - fn not_a_permutation() { - bimap_from_lists(vec!['a', 'a', 'b'], vec!['a', 'b', 'b']); - } -} diff --git a/waksman/src/gates/assert_le.rs b/waksman/src/gates/assert_le.rs deleted file mode 100644 index 0213dd38..00000000 --- a/waksman/src/gates/assert_le.rs +++ /dev/null @@ -1,629 +0,0 @@ -use std::marker::PhantomData; - -use plonky2::gates::gate::Gate; -use plonky2::gates::packed_util::PackedEvaluableBase; -use plonky2::gates::util::StridedConstraintConsumer; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::ext_target::ExtensionTarget; -use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; -use plonky2::iop::target::Target; -use plonky2::iop::wire::Wire; -use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_circuit}; -use plonky2::plonk::vars::{ - EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, - EvaluationVarsBasePacked, -}; -use plonky2_field::extension::Extendable; -use plonky2_field::packed::PackedField; -use plonky2_field::types::{Field, Field64}; -use plonky2_util::{bits_u64, ceil_div_usize}; - -// TODO: replace/merge this gate with `ComparisonGate`. - -/// A gate for checking that one value is less than or equal to another. -#[derive(Clone, Debug)] -pub struct AssertLessThanGate, const D: usize> { - pub(crate) num_bits: usize, - pub(crate) num_chunks: usize, - _phantom: PhantomData, -} - -impl, const D: usize> AssertLessThanGate { - pub fn new(num_bits: usize, num_chunks: usize) -> Self { - debug_assert!(num_bits < bits_u64(F::ORDER)); - Self { - num_bits, - num_chunks, - _phantom: PhantomData, - } - } - - pub fn chunk_bits(&self) -> usize { - ceil_div_usize(self.num_bits, self.num_chunks) - } - - pub fn wire_first_input(&self) -> usize { - 0 - } - - pub fn wire_second_input(&self) -> usize { - 1 - } - - pub fn wire_most_significant_diff(&self) -> usize { - 2 - } - - pub fn wire_first_chunk_val(&self, chunk: usize) -> usize { - debug_assert!(chunk < self.num_chunks); - 3 + chunk - } - - pub fn wire_second_chunk_val(&self, chunk: usize) -> usize { - debug_assert!(chunk < self.num_chunks); - 3 + self.num_chunks + chunk - } - - pub fn wire_equality_dummy(&self, chunk: usize) -> usize { - debug_assert!(chunk < self.num_chunks); - 3 + 2 * self.num_chunks + chunk - } - - pub fn wire_chunks_equal(&self, chunk: usize) -> usize { - debug_assert!(chunk < self.num_chunks); - 3 + 3 * self.num_chunks + chunk - } - - pub fn wire_intermediate_value(&self, chunk: usize) -> usize { - debug_assert!(chunk < self.num_chunks); - 3 + 4 * self.num_chunks + chunk - } -} - -impl, const D: usize> Gate for AssertLessThanGate { - fn id(&self) -> String { - format!("{self:?}") - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - let first_input = vars.local_wires[self.wire_first_input()]; - let second_input = vars.local_wires[self.wire_second_input()]; - - // Get chunks and assert that they match - let first_chunks: Vec = (0..self.num_chunks) - .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) - .collect(); - let second_chunks: Vec = (0..self.num_chunks) - .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) - .collect(); - - let first_chunks_combined = reduce_with_powers( - &first_chunks, - F::Extension::from_canonical_usize(1 << self.chunk_bits()), - ); - let second_chunks_combined = reduce_with_powers( - &second_chunks, - F::Extension::from_canonical_usize(1 << self.chunk_bits()), - ); - - constraints.push(first_chunks_combined - first_input); - constraints.push(second_chunks_combined - second_input); - - let chunk_size = 1 << self.chunk_bits(); - - let mut most_significant_diff_so_far = F::Extension::ZERO; - - for i in 0..self.num_chunks { - // Range-check the chunks to be less than `chunk_size`. - let first_product = (0..chunk_size) - .map(|x| first_chunks[i] - F::Extension::from_canonical_usize(x)) - .product(); - let second_product = (0..chunk_size) - .map(|x| second_chunks[i] - F::Extension::from_canonical_usize(x)) - .product(); - constraints.push(first_product); - constraints.push(second_product); - - let difference = second_chunks[i] - first_chunks[i]; - let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; - let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; - - // Two constraints to assert that `chunks_equal` is valid. - constraints.push(difference * equality_dummy - (F::Extension::ONE - chunks_equal)); - constraints.push(chunks_equal * difference); - - // Update `most_significant_diff_so_far`. - let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; - constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far); - most_significant_diff_so_far = - intermediate_value + (F::Extension::ONE - chunks_equal) * difference; - } - - let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; - constraints.push(most_significant_diff - most_significant_diff_so_far); - - // Range check `most_significant_diff` to be less than `chunk_size`. - let product = (0..chunk_size) - .map(|x| most_significant_diff - F::Extension::from_canonical_usize(x)) - .product(); - constraints.push(product); - - constraints - } - - fn eval_unfiltered_base_one( - &self, - _vars: EvaluationVarsBase, - _yield_constr: StridedConstraintConsumer, - ) { - panic!("use eval_unfiltered_base_packed instead"); - } - - fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { - self.eval_unfiltered_base_batch_packed(vars_base) - } - - fn eval_unfiltered_circuit( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec> { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - let first_input = vars.local_wires[self.wire_first_input()]; - let second_input = vars.local_wires[self.wire_second_input()]; - - // Get chunks and assert that they match - let first_chunks: Vec> = (0..self.num_chunks) - .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) - .collect(); - let second_chunks: Vec> = (0..self.num_chunks) - .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) - .collect(); - - let chunk_base = builder.constant(F::from_canonical_usize(1 << self.chunk_bits())); - let first_chunks_combined = - reduce_with_powers_ext_circuit(builder, &first_chunks, chunk_base); - let second_chunks_combined = - reduce_with_powers_ext_circuit(builder, &second_chunks, chunk_base); - - constraints.push(builder.sub_extension(first_chunks_combined, first_input)); - constraints.push(builder.sub_extension(second_chunks_combined, second_input)); - - let chunk_size = 1 << self.chunk_bits(); - - let mut most_significant_diff_so_far = builder.zero_extension(); - - let one = builder.one_extension(); - // Find the chosen chunk. - for i in 0..self.num_chunks { - // Range-check the chunks to be less than `chunk_size`. - let mut first_product = one; - let mut second_product = one; - for x in 0..chunk_size { - let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x)); - let first_diff = builder.sub_extension(first_chunks[i], x_f); - let second_diff = builder.sub_extension(second_chunks[i], x_f); - first_product = builder.mul_extension(first_product, first_diff); - second_product = builder.mul_extension(second_product, second_diff); - } - constraints.push(first_product); - constraints.push(second_product); - - let difference = builder.sub_extension(second_chunks[i], first_chunks[i]); - let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; - let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; - - // Two constraints to assert that `chunks_equal` is valid. - let diff_times_equal = builder.mul_extension(difference, equality_dummy); - let not_equal = builder.sub_extension(one, chunks_equal); - constraints.push(builder.sub_extension(diff_times_equal, not_equal)); - constraints.push(builder.mul_extension(chunks_equal, difference)); - - // Update `most_significant_diff_so_far`. - let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; - let old_diff = builder.mul_extension(chunks_equal, most_significant_diff_so_far); - constraints.push(builder.sub_extension(intermediate_value, old_diff)); - - let not_equal = builder.sub_extension(one, chunks_equal); - let new_diff = builder.mul_extension(not_equal, difference); - most_significant_diff_so_far = builder.add_extension(intermediate_value, new_diff); - } - - let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; - constraints - .push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far)); - - // Range check `most_significant_diff` to be less than `chunk_size`. - let mut product = builder.one_extension(); - for x in 0..chunk_size { - let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x)); - let diff = builder.sub_extension(most_significant_diff, x_f); - product = builder.mul_extension(product, diff); - } - constraints.push(product); - - constraints - } - - fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { - let gen = AssertLessThanGenerator:: { - row, - gate: self.clone(), - }; - vec![Box::new(gen.adapter())] - } - - fn num_wires(&self) -> usize { - self.wire_intermediate_value(self.num_chunks - 1) + 1 - } - - fn num_constants(&self) -> usize { - 0 - } - - fn degree(&self) -> usize { - 1 << self.chunk_bits() - } - - fn num_constraints(&self) -> usize { - 4 + 5 * self.num_chunks - } -} - -impl, const D: usize> PackedEvaluableBase - for AssertLessThanGate -{ - fn eval_unfiltered_base_packed>( - &self, - vars: EvaluationVarsBasePacked

, - mut yield_constr: StridedConstraintConsumer

, - ) { - let first_input = vars.local_wires[self.wire_first_input()]; - let second_input = vars.local_wires[self.wire_second_input()]; - - // Get chunks and assert that they match - let first_chunks: Vec<_> = (0..self.num_chunks) - .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) - .collect(); - let second_chunks: Vec<_> = (0..self.num_chunks) - .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) - .collect(); - - let first_chunks_combined = reduce_with_powers( - &first_chunks, - F::from_canonical_usize(1 << self.chunk_bits()), - ); - let second_chunks_combined = reduce_with_powers( - &second_chunks, - F::from_canonical_usize(1 << self.chunk_bits()), - ); - - yield_constr.one(first_chunks_combined - first_input); - yield_constr.one(second_chunks_combined - second_input); - - let chunk_size = 1 << self.chunk_bits(); - - let mut most_significant_diff_so_far = P::ZEROS; - - for i in 0..self.num_chunks { - // Range-check the chunks to be less than `chunk_size`. - let first_product = (0..chunk_size) - .map(|x| first_chunks[i] - F::from_canonical_usize(x)) - .product(); - let second_product = (0..chunk_size) - .map(|x| second_chunks[i] - F::from_canonical_usize(x)) - .product(); - yield_constr.one(first_product); - yield_constr.one(second_product); - - let difference = second_chunks[i] - first_chunks[i]; - let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; - let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; - - // Two constraints to assert that `chunks_equal` is valid. - yield_constr.one(difference * equality_dummy - (P::ONES - chunks_equal)); - yield_constr.one(chunks_equal * difference); - - // Update `most_significant_diff_so_far`. - let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; - yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far); - most_significant_diff_so_far = - intermediate_value + (P::ONES - chunks_equal) * difference; - } - - let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; - yield_constr.one(most_significant_diff - most_significant_diff_so_far); - - // Range check `most_significant_diff` to be less than `chunk_size`. - let product = (0..chunk_size) - .map(|x| most_significant_diff - F::from_canonical_usize(x)) - .product(); - yield_constr.one(product); - } -} - -#[derive(Debug)] -struct AssertLessThanGenerator, const D: usize> { - row: usize, - gate: AssertLessThanGate, -} - -impl, const D: usize> SimpleGenerator - for AssertLessThanGenerator -{ - fn dependencies(&self) -> Vec { - let local_target = |column| Target::wire(self.row, column); - - vec![ - local_target(self.gate.wire_first_input()), - local_target(self.gate.wire_second_input()), - ] - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let local_wire = |column| Wire { - row: self.row, - column, - }; - - let get_local_wire = |column| witness.get_wire(local_wire(column)); - - let first_input = get_local_wire(self.gate.wire_first_input()); - let second_input = get_local_wire(self.gate.wire_second_input()); - - let first_input_u64 = first_input.to_canonical_u64(); - let second_input_u64 = second_input.to_canonical_u64(); - - debug_assert!(first_input_u64 < second_input_u64); - - let chunk_size = 1 << self.gate.chunk_bits(); - let first_input_chunks: Vec = (0..self.gate.num_chunks) - .scan(first_input_u64, |acc, _| { - let tmp = *acc % chunk_size; - *acc /= chunk_size; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); - let second_input_chunks: Vec = (0..self.gate.num_chunks) - .scan(second_input_u64, |acc, _| { - let tmp = *acc % chunk_size; - *acc /= chunk_size; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); - - let chunks_equal: Vec = (0..self.gate.num_chunks) - .map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i])) - .collect(); - let equality_dummies: Vec = first_input_chunks - .iter() - .zip(second_input_chunks.iter()) - .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) }) - .collect(); - - let mut most_significant_diff_so_far = F::ZERO; - let mut intermediate_values = Vec::new(); - for i in 0..self.gate.num_chunks { - if first_input_chunks[i] != second_input_chunks[i] { - most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i]; - intermediate_values.push(F::ZERO); - } else { - intermediate_values.push(most_significant_diff_so_far); - } - } - let most_significant_diff = most_significant_diff_so_far; - - out_buffer.set_wire( - local_wire(self.gate.wire_most_significant_diff()), - most_significant_diff, - ); - for i in 0..self.gate.num_chunks { - out_buffer.set_wire( - local_wire(self.gate.wire_first_chunk_val(i)), - first_input_chunks[i], - ); - out_buffer.set_wire( - local_wire(self.gate.wire_second_chunk_val(i)), - second_input_chunks[i], - ); - out_buffer.set_wire( - local_wire(self.gate.wire_equality_dummy(i)), - equality_dummies[i], - ); - out_buffer.set_wire(local_wire(self.gate.wire_chunks_equal(i)), chunks_equal[i]); - out_buffer.set_wire( - local_wire(self.gate.wire_intermediate_value(i)), - intermediate_values[i], - ); - } - } -} - -#[cfg(test)] -mod tests { - use core::marker::PhantomData; - - use anyhow::Result; - use plonky2::gates::gate::Gate; - use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; - use plonky2::hash::hash_types::HashOut; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use plonky2::plonk::vars::EvaluationVars; - use plonky2_field::extension::quartic::QuarticExtension; - use plonky2_field::goldilocks_field::GoldilocksField; - use plonky2_field::types::{Field, PrimeField64, Sample}; - use rand::Rng; - - use crate::gates::assert_le::AssertLessThanGate; - - #[test] - fn wire_indices() { - type AG = AssertLessThanGate; - let num_bits = 40; - let num_chunks = 5; - - let gate = AG { - num_bits, - num_chunks, - _phantom: PhantomData, - }; - - assert_eq!(gate.wire_first_input(), 0); - assert_eq!(gate.wire_second_input(), 1); - assert_eq!(gate.wire_most_significant_diff(), 2); - assert_eq!(gate.wire_first_chunk_val(0), 3); - assert_eq!(gate.wire_first_chunk_val(4), 7); - assert_eq!(gate.wire_second_chunk_val(0), 8); - assert_eq!(gate.wire_second_chunk_val(4), 12); - assert_eq!(gate.wire_equality_dummy(0), 13); - assert_eq!(gate.wire_equality_dummy(4), 17); - assert_eq!(gate.wire_chunks_equal(0), 18); - assert_eq!(gate.wire_chunks_equal(4), 22); - assert_eq!(gate.wire_intermediate_value(0), 23); - assert_eq!(gate.wire_intermediate_value(4), 27); - } - - #[test] - fn low_degree() { - let num_bits = 20; - let num_chunks = 4; - - test_low_degree::(AssertLessThanGate::<_, 4>::new( - num_bits, num_chunks, - )) - } - - #[test] - fn eval_fns() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let num_bits = 20; - let num_chunks = 4; - - test_eval_fns::(AssertLessThanGate::<_, D>::new(num_bits, num_chunks)) - } - - #[test] - fn test_gate_constraint() { - type F = GoldilocksField; - type FF = QuarticExtension; - const D: usize = 4; - - let num_bits = 40; - let num_chunks = 5; - let chunk_bits = num_bits / num_chunks; - - // Returns the local wires for an AssertLessThanGate given the two inputs. - let get_wires = |first_input: F, second_input: F| -> Vec { - let mut v = Vec::new(); - - let first_input_u64 = first_input.to_canonical_u64(); - let second_input_u64 = second_input.to_canonical_u64(); - - let chunk_size = 1 << chunk_bits; - let mut first_input_chunks: Vec = (0..num_chunks) - .scan(first_input_u64, |acc, _| { - let tmp = *acc % chunk_size; - *acc /= chunk_size; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); - let mut second_input_chunks: Vec = (0..num_chunks) - .scan(second_input_u64, |acc, _| { - let tmp = *acc % chunk_size; - *acc /= chunk_size; - Some(F::from_canonical_u64(tmp)) - }) - .collect(); - - let mut chunks_equal: Vec = (0..num_chunks) - .map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i])) - .collect(); - let mut equality_dummies: Vec = first_input_chunks - .iter() - .zip(second_input_chunks.iter()) - .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) }) - .collect(); - - let mut most_significant_diff_so_far = F::ZERO; - let mut intermediate_values = Vec::new(); - for i in 0..num_chunks { - if first_input_chunks[i] != second_input_chunks[i] { - most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i]; - intermediate_values.push(F::ZERO); - } else { - intermediate_values.push(most_significant_diff_so_far); - } - } - let most_significant_diff = most_significant_diff_so_far; - - v.push(first_input); - v.push(second_input); - v.push(most_significant_diff); - v.append(&mut first_input_chunks); - v.append(&mut second_input_chunks); - v.append(&mut equality_dummies); - v.append(&mut chunks_equal); - v.append(&mut intermediate_values); - - v.iter().map(|&x| x.into()).collect() - }; - - let mut rng = rand::thread_rng(); - let max: u64 = 1 << (num_bits - 1); - let first_input_u64 = rng.gen_range(0..max); - let second_input_u64 = { - let mut val = rng.gen_range(0..max); - while val < first_input_u64 { - val = rng.gen_range(0..max); - } - val - }; - - let first_input = F::from_canonical_u64(first_input_u64); - let second_input = F::from_canonical_u64(second_input_u64); - - let less_than_gate = AssertLessThanGate:: { - num_bits, - num_chunks, - _phantom: PhantomData, - }; - let less_than_vars = EvaluationVars { - local_constants: &[], - local_wires: &get_wires(first_input, second_input), - public_inputs_hash: &HashOut::rand(), - }; - assert!( - less_than_gate - .eval_unfiltered(less_than_vars) - .iter() - .all(|x| x.is_zero()), - "Gate constraints are not satisfied." - ); - - let equal_gate = AssertLessThanGate:: { - num_bits, - num_chunks, - _phantom: PhantomData, - }; - let equal_vars = EvaluationVars { - local_constants: &[], - local_wires: &get_wires(first_input, first_input), - public_inputs_hash: &HashOut::rand(), - }; - assert!( - equal_gate - .eval_unfiltered(equal_vars) - .iter() - .all(|x| x.is_zero()), - "Gate constraints are not satisfied." - ); - } -} diff --git a/waksman/src/gates/mod.rs b/waksman/src/gates/mod.rs deleted file mode 100644 index c73890b1..00000000 --- a/waksman/src/gates/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod assert_le; -pub mod switch; diff --git a/waksman/src/gates/switch.rs b/waksman/src/gates/switch.rs deleted file mode 100644 index b868916e..00000000 --- a/waksman/src/gates/switch.rs +++ /dev/null @@ -1,454 +0,0 @@ -use std::marker::PhantomData; - -use array_tool::vec::Union; -use plonky2::gates::gate::Gate; -use plonky2::gates::packed_util::PackedEvaluableBase; -use plonky2::gates::util::StridedConstraintConsumer; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::ext_target::ExtensionTarget; -use plonky2::iop::generator::{GeneratedValues, WitnessGenerator}; -use plonky2::iop::target::Target; -use plonky2::iop::wire::Wire; -use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::CircuitConfig; -use plonky2::plonk::vars::{ - EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, - EvaluationVarsBasePacked, -}; -use plonky2_field::extension::Extendable; -use plonky2_field::packed::PackedField; -use plonky2_field::types::Field; - -/// A gate for conditionally swapping input values based on a boolean. -#[derive(Copy, Clone, Debug)] -pub struct SwitchGate, const D: usize> { - pub(crate) chunk_size: usize, - pub(crate) num_copies: usize, - _phantom: PhantomData, -} - -impl, const D: usize> SwitchGate { - pub fn new(num_copies: usize, chunk_size: usize) -> Self { - Self { - chunk_size, - num_copies, - _phantom: PhantomData, - } - } - - pub fn new_from_config(config: &CircuitConfig, chunk_size: usize) -> Self { - let num_copies = Self::max_num_copies(config.num_routed_wires, chunk_size); - Self::new(num_copies, chunk_size) - } - - pub fn max_num_copies(num_routed_wires: usize, chunk_size: usize) -> usize { - num_routed_wires / (4 * chunk_size + 1) - } - - pub fn wire_first_input(&self, copy: usize, element: usize) -> usize { - debug_assert!(element < self.chunk_size); - copy * (4 * self.chunk_size + 1) + element - } - - pub fn wire_second_input(&self, copy: usize, element: usize) -> usize { - debug_assert!(element < self.chunk_size); - copy * (4 * self.chunk_size + 1) + self.chunk_size + element - } - - pub fn wire_first_output(&self, copy: usize, element: usize) -> usize { - debug_assert!(element < self.chunk_size); - copy * (4 * self.chunk_size + 1) + 2 * self.chunk_size + element - } - - pub fn wire_second_output(&self, copy: usize, element: usize) -> usize { - debug_assert!(element < self.chunk_size); - copy * (4 * self.chunk_size + 1) + 3 * self.chunk_size + element - } - - pub fn wire_switch_bool(&self, copy: usize) -> usize { - debug_assert!(copy < self.num_copies); - copy * (4 * self.chunk_size + 1) + 4 * self.chunk_size - } -} - -impl, const D: usize> Gate for SwitchGate { - fn id(&self) -> String { - format!("{self:?}") - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - for c in 0..self.num_copies { - let switch_bool = vars.local_wires[self.wire_switch_bool(c)]; - let not_switch = F::Extension::ONE - switch_bool; - - for e in 0..self.chunk_size { - let first_input = vars.local_wires[self.wire_first_input(c, e)]; - let second_input = vars.local_wires[self.wire_second_input(c, e)]; - let first_output = vars.local_wires[self.wire_first_output(c, e)]; - let second_output = vars.local_wires[self.wire_second_output(c, e)]; - - constraints.push(switch_bool * (first_input - second_output)); - constraints.push(switch_bool * (second_input - first_output)); - constraints.push(not_switch * (first_input - first_output)); - constraints.push(not_switch * (second_input - second_output)); - } - } - - constraints - } - - fn eval_unfiltered_base_one( - &self, - _vars: EvaluationVarsBase, - _yield_constr: StridedConstraintConsumer, - ) { - panic!("use eval_unfiltered_base_packed instead"); - } - - fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { - self.eval_unfiltered_base_batch_packed(vars_base) - } - - fn eval_unfiltered_circuit( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec> { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - let one = builder.one_extension(); - for c in 0..self.num_copies { - let switch_bool = vars.local_wires[self.wire_switch_bool(c)]; - let not_switch = builder.sub_extension(one, switch_bool); - - for e in 0..self.chunk_size { - let first_input = vars.local_wires[self.wire_first_input(c, e)]; - let second_input = vars.local_wires[self.wire_second_input(c, e)]; - let first_output = vars.local_wires[self.wire_first_output(c, e)]; - let second_output = vars.local_wires[self.wire_second_output(c, e)]; - - let first_switched = builder.sub_extension(first_input, second_output); - let first_switched_constraint = builder.mul_extension(switch_bool, first_switched); - constraints.push(first_switched_constraint); - - let second_switched = builder.sub_extension(second_input, first_output); - let second_switched_constraint = - builder.mul_extension(switch_bool, second_switched); - constraints.push(second_switched_constraint); - - let first_not_switched = builder.sub_extension(first_input, first_output); - let first_not_switched_constraint = - builder.mul_extension(not_switch, first_not_switched); - constraints.push(first_not_switched_constraint); - - let second_not_switched = builder.sub_extension(second_input, second_output); - let second_not_switched_constraint = - builder.mul_extension(not_switch, second_not_switched); - constraints.push(second_not_switched_constraint); - } - } - - constraints - } - - fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { - (0..self.num_copies) - .map(|c| { - let g: Box> = Box::new(SwitchGenerator:: { - row, - gate: *self, - copy: c, - }); - g - }) - .collect() - } - - fn num_wires(&self) -> usize { - self.wire_switch_bool(self.num_copies - 1) + 1 - } - - fn num_constants(&self) -> usize { - 0 - } - - fn degree(&self) -> usize { - 2 - } - - fn num_constraints(&self) -> usize { - 4 * self.num_copies * self.chunk_size - } -} - -impl, const D: usize> PackedEvaluableBase for SwitchGate { - fn eval_unfiltered_base_packed>( - &self, - vars: EvaluationVarsBasePacked

, - mut yield_constr: StridedConstraintConsumer

, - ) { - for c in 0..self.num_copies { - let switch_bool = vars.local_wires[self.wire_switch_bool(c)]; - let not_switch = P::ONES - switch_bool; - - for e in 0..self.chunk_size { - let first_input = vars.local_wires[self.wire_first_input(c, e)]; - let second_input = vars.local_wires[self.wire_second_input(c, e)]; - let first_output = vars.local_wires[self.wire_first_output(c, e)]; - let second_output = vars.local_wires[self.wire_second_output(c, e)]; - - yield_constr.one(switch_bool * (first_input - second_output)); - yield_constr.one(switch_bool * (second_input - first_output)); - yield_constr.one(not_switch * (first_input - first_output)); - yield_constr.one(not_switch * (second_input - second_output)); - } - } - } -} - -#[derive(Debug)] -struct SwitchGenerator, const D: usize> { - row: usize, - gate: SwitchGate, - copy: usize, -} - -impl, const D: usize> SwitchGenerator { - fn in_out_dependencies(&self) -> Vec { - let local_target = |column| Target::wire(self.row, column); - - let mut deps = Vec::new(); - for e in 0..self.gate.chunk_size { - deps.push(local_target(self.gate.wire_first_input(self.copy, e))); - deps.push(local_target(self.gate.wire_second_input(self.copy, e))); - deps.push(local_target(self.gate.wire_first_output(self.copy, e))); - deps.push(local_target(self.gate.wire_second_output(self.copy, e))); - } - - deps - } - - fn in_switch_dependencies(&self) -> Vec { - let local_target = |column| Target::wire(self.row, column); - - let mut deps = Vec::new(); - for e in 0..self.gate.chunk_size { - deps.push(local_target(self.gate.wire_first_input(self.copy, e))); - deps.push(local_target(self.gate.wire_second_input(self.copy, e))); - deps.push(local_target(self.gate.wire_switch_bool(self.copy))); - } - - deps - } - - fn run_in_out(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let local_wire = |column| Wire { - row: self.row, - column, - }; - - let get_local_wire = |column| witness.get_wire(local_wire(column)); - - let switch_bool_wire = local_wire(self.gate.wire_switch_bool(self.copy)); - - let mut first_inputs = Vec::new(); - let mut second_inputs = Vec::new(); - let mut first_outputs = Vec::new(); - let mut second_outputs = Vec::new(); - for e in 0..self.gate.chunk_size { - first_inputs.push(get_local_wire(self.gate.wire_first_input(self.copy, e))); - second_inputs.push(get_local_wire(self.gate.wire_second_input(self.copy, e))); - first_outputs.push(get_local_wire(self.gate.wire_first_output(self.copy, e))); - second_outputs.push(get_local_wire(self.gate.wire_second_output(self.copy, e))); - } - - if first_outputs == first_inputs && second_outputs == second_inputs { - out_buffer.set_wire(switch_bool_wire, F::ZERO); - } else if first_outputs == second_inputs && second_outputs == first_inputs { - out_buffer.set_wire(switch_bool_wire, F::ONE); - } else { - panic!("No permutation from given inputs to given outputs"); - } - } - - fn run_in_switch(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let local_wire = |column| Wire { - row: self.row, - column, - }; - - let get_local_wire = |column| witness.get_wire(local_wire(column)); - - let switch_bool = get_local_wire(self.gate.wire_switch_bool(self.copy)); - for e in 0..self.gate.chunk_size { - let first_output_wire = local_wire(self.gate.wire_first_output(self.copy, e)); - let second_output_wire = local_wire(self.gate.wire_second_output(self.copy, e)); - let first_input = get_local_wire(self.gate.wire_first_input(self.copy, e)); - let second_input = get_local_wire(self.gate.wire_second_input(self.copy, e)); - - let (first_output, second_output) = if switch_bool == F::ZERO { - (first_input, second_input) - } else if switch_bool == F::ONE { - (second_input, first_input) - } else { - panic!("Invalid switch bool value"); - }; - - out_buffer.set_wire(first_output_wire, first_output); - out_buffer.set_wire(second_output_wire, second_output); - } - } -} - -impl, const D: usize> WitnessGenerator for SwitchGenerator { - fn watch_list(&self) -> Vec { - self.in_out_dependencies() - .union(self.in_switch_dependencies()) - } - - fn run(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) -> bool { - if witness.contains_all(&self.in_out_dependencies()) { - self.run_in_out(witness, out_buffer); - true - } else if witness.contains_all(&self.in_switch_dependencies()) { - self.run_in_switch(witness, out_buffer); - true - } else { - false - } - } -} - -#[cfg(test)] -mod tests { - use std::marker::PhantomData; - - use anyhow::Result; - use plonky2::gates::gate::Gate; - use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; - use plonky2::hash::hash_types::HashOut; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use plonky2::plonk::vars::EvaluationVars; - use plonky2_field::goldilocks_field::GoldilocksField; - use plonky2_field::types::{Field, Sample}; - - use crate::gates::switch::SwitchGate; - - #[test] - fn wire_indices() { - type SG = SwitchGate; - let num_copies = 3; - let chunk_size = 3; - - let gate = SG { - chunk_size, - num_copies, - _phantom: PhantomData, - }; - - assert_eq!(gate.wire_first_input(0, 0), 0); - assert_eq!(gate.wire_first_input(0, 2), 2); - assert_eq!(gate.wire_second_input(0, 0), 3); - assert_eq!(gate.wire_second_input(0, 2), 5); - assert_eq!(gate.wire_first_output(0, 0), 6); - assert_eq!(gate.wire_second_output(0, 2), 11); - assert_eq!(gate.wire_switch_bool(0), 12); - assert_eq!(gate.wire_first_input(1, 0), 13); - assert_eq!(gate.wire_second_output(1, 2), 24); - assert_eq!(gate.wire_switch_bool(1), 25); - assert_eq!(gate.wire_first_input(2, 0), 26); - assert_eq!(gate.wire_second_output(2, 2), 37); - assert_eq!(gate.wire_switch_bool(2), 38); - } - - #[test] - fn low_degree() { - test_low_degree::(SwitchGate::<_, 4>::new_from_config( - &CircuitConfig::standard_recursion_config(), - 3, - )); - } - - #[test] - fn eval_fns() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - test_eval_fns::(SwitchGate::<_, D>::new_from_config( - &CircuitConfig::standard_recursion_config(), - 3, - )) - } - - #[test] - fn test_gate_constraint() { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type FF = >::FE; - const CHUNK_SIZE: usize = 4; - let num_copies = 3; - - /// Returns the local wires for a switch gate given the inputs and the switch booleans. - fn get_wires( - first_inputs: Vec>, - second_inputs: Vec>, - switch_bools: Vec, - ) -> Vec { - let num_copies = first_inputs.len(); - - let mut v = Vec::new(); - for c in 0..num_copies { - let switch = switch_bools[c]; - - let mut first_input_chunk = Vec::with_capacity(CHUNK_SIZE); - let mut second_input_chunk = Vec::with_capacity(CHUNK_SIZE); - let mut first_output_chunk = Vec::with_capacity(CHUNK_SIZE); - let mut second_output_chunk = Vec::with_capacity(CHUNK_SIZE); - for e in 0..CHUNK_SIZE { - let first_input = first_inputs[c][e]; - let second_input = second_inputs[c][e]; - let first_output = if switch { second_input } else { first_input }; - let second_output = if switch { first_input } else { second_input }; - first_input_chunk.push(first_input); - second_input_chunk.push(second_input); - first_output_chunk.push(first_output); - second_output_chunk.push(second_output); - } - v.append(&mut first_input_chunk); - v.append(&mut second_input_chunk); - v.append(&mut first_output_chunk); - v.append(&mut second_output_chunk); - - v.push(F::from_bool(switch)); - } - - v.iter().map(|&x| x.into()).collect() - } - - let first_inputs: Vec> = (0..num_copies).map(|_| F::rand_vec(CHUNK_SIZE)).collect(); - let second_inputs: Vec> = (0..num_copies).map(|_| F::rand_vec(CHUNK_SIZE)).collect(); - let switch_bools = vec![true, false, true]; - - let gate = SwitchGate:: { - chunk_size: CHUNK_SIZE, - num_copies, - _phantom: PhantomData, - }; - - let vars = EvaluationVars { - local_constants: &[], - local_wires: &get_wires(first_inputs, second_inputs, switch_bools), - public_inputs_hash: &HashOut::rand(), - }; - - assert!( - gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), - "Gate constraints are not satisfied." - ); - } -} diff --git a/waksman/src/lib.rs b/waksman/src/lib.rs deleted file mode 100644 index e9b0d4c5..00000000 --- a/waksman/src/lib.rs +++ /dev/null @@ -1,11 +0,0 @@ -#![allow(clippy::new_without_default)] -#![allow(clippy::too_many_arguments)] -#![allow(clippy::type_complexity)] -#![allow(clippy::len_without_is_empty)] -#![allow(clippy::needless_range_loop)] -#![allow(clippy::return_self_not_must_use)] - -pub mod bimap; -pub mod gates; -pub mod permutation; -pub mod sorting; diff --git a/waksman/src/permutation.rs b/waksman/src/permutation.rs deleted file mode 100644 index 57ede529..00000000 --- a/waksman/src/permutation.rs +++ /dev/null @@ -1,509 +0,0 @@ -use std::collections::BTreeMap; -use std::marker::PhantomData; - -use plonky2::field::extension::Extendable; -use plonky2::field::types::Field; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; -use plonky2::iop::target::Target; -use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; -use plonky2::plonk::circuit_builder::CircuitBuilder; - -use crate::bimap::bimap_from_lists; -use crate::gates::switch::SwitchGate; - -/// Assert that two lists of expressions evaluate to permutations of one another. -pub fn assert_permutation_circuit, const D: usize>( - builder: &mut CircuitBuilder, - a: Vec>, - b: Vec>, -) { - assert_eq!( - a.len(), - b.len(), - "Permutation must have same number of inputs and outputs" - ); - assert_eq!(a[0].len(), b[0].len(), "Chunk size must be the same"); - - let chunk_size = a[0].len(); - - match a.len() { - // Two empty lists are permutations of one another, trivially. - 0 => (), - // Two singleton lists are permutations of one another as long as their items are equal. - 1 => { - for e in 0..chunk_size { - builder.connect(a[0][e], b[0][e]) - } - } - 2 => assert_permutation_2x2_circuit( - builder, - a[0].clone(), - a[1].clone(), - b[0].clone(), - b[1].clone(), - ), - // For larger lists, we recursively use two smaller permutation networks. - _ => assert_permutation_helper_circuit(builder, a, b), - } -} - -/// Assert that [a1, a2] is a permutation of [b1, b2]. -fn assert_permutation_2x2_circuit, const D: usize>( - builder: &mut CircuitBuilder, - a1: Vec, - a2: Vec, - b1: Vec, - b2: Vec, -) { - assert!( - a1.len() == a2.len() && a2.len() == b1.len() && b1.len() == b2.len(), - "Chunk size must be the same" - ); - - let chunk_size = a1.len(); - - let (_switch, gate_out1, gate_out2) = create_switch_circuit(builder, a1, a2); - for e in 0..chunk_size { - builder.connect(b1[e], gate_out1[e]); - builder.connect(b2[e], gate_out2[e]); - } -} - -/// Given two input wire chunks, add a new switch to the circuit (by adding one copy to a switch -/// gate). Returns the wire for the switch boolean, and the two output wire chunks. -fn create_switch_circuit, const D: usize>( - builder: &mut CircuitBuilder, - a1: Vec, - a2: Vec, -) -> (Target, Vec, Vec) { - assert_eq!(a1.len(), a2.len(), "Chunk size must be the same"); - - let chunk_size = a1.len(); - - let gate = SwitchGate::new_from_config(&builder.config, chunk_size); - let params = vec![F::from_canonical_usize(chunk_size)]; - let (row, next_copy) = builder.find_slot(gate, ¶ms, &[]); - - let mut c = Vec::new(); - let mut d = Vec::new(); - for e in 0..chunk_size { - builder.connect( - a1[e], - Target::wire(row, gate.wire_first_input(next_copy, e)), - ); - builder.connect( - a2[e], - Target::wire(row, gate.wire_second_input(next_copy, e)), - ); - c.push(Target::wire(row, gate.wire_first_output(next_copy, e))); - d.push(Target::wire(row, gate.wire_second_output(next_copy, e))); - } - - let switch = Target::wire(row, gate.wire_switch_bool(next_copy)); - - (switch, c, d) -} - -fn assert_permutation_helper_circuit, const D: usize>( - builder: &mut CircuitBuilder, - a: Vec>, - b: Vec>, -) { - assert_eq!( - a.len(), - b.len(), - "Permutation must have same number of inputs and outputs" - ); - assert_eq!(a[0].len(), b[0].len(), "Chunk size must be the same"); - - let n = a.len(); - let even = n % 2 == 0; - - let mut child_1_a = Vec::new(); - let mut child_1_b = Vec::new(); - let mut child_2_a = Vec::new(); - let mut child_2_b = Vec::new(); - - // See Figure 8 in the AS-Waksman paper. - let a_num_switches = n / 2; - let b_num_switches = if even { - a_num_switches - 1 - } else { - a_num_switches - }; - - let mut a_switches = Vec::new(); - let mut b_switches = Vec::new(); - for i in 0..a_num_switches { - let (switch, out_1, out_2) = - create_switch_circuit(builder, a[i * 2].clone(), a[i * 2 + 1].clone()); - a_switches.push(switch); - child_1_a.push(out_1); - child_2_a.push(out_2); - } - for i in 0..b_num_switches { - let (switch, out_1, out_2) = - create_switch_circuit(builder, b[i * 2].clone(), b[i * 2 + 1].clone()); - b_switches.push(switch); - child_1_b.push(out_1); - child_2_b.push(out_2); - } - - // See Figure 8 in the AS-Waksman paper. - if even { - child_1_b.push(b[n - 2].clone()); - child_2_b.push(b[n - 1].clone()); - } else { - child_2_a.push(a[n - 1].clone()); - child_2_b.push(b[n - 1].clone()); - } - - assert_permutation_circuit(builder, child_1_a, child_1_b); - assert_permutation_circuit(builder, child_2_a, child_2_b); - - builder.add_simple_generator(PermutationGenerator:: { - a, - b, - a_switches, - b_switches, - _phantom: PhantomData, - }); -} - -fn route( - a_values: Vec>, - b_values: Vec>, - a_switches: Vec, - b_switches: Vec, - witness: &PartitionWitness, - out_buffer: &mut GeneratedValues, -) { - assert_eq!(a_values.len(), b_values.len()); - let n = a_values.len(); - let even = n % 2 == 0; - - // We use a bimap to match indices of values in a to indices of the same values in b. - // This means that given a wire on one side, we can easily find the matching wire on the other side. - let ab_map = bimap_from_lists(a_values, b_values); - - let switches = [a_switches, b_switches]; - - // We keep track of the new wires we've routed (after routing some wires, we need to check `witness` - // and `newly_set` instead of just `witness`. - let mut newly_set = [vec![false; n], vec![false; n]]; - - // Given a side and an index, returns the index in the other side that corresponds to the same value. - let ab_map_by_side = |side: usize, index: usize| -> usize { - *match side { - 0 => ab_map.get_by_left(&index), - 1 => ab_map.get_by_right(&index), - _ => panic!("Expected side to be 0 or 1"), - } - .unwrap() - }; - - // We maintain two maps for wires which have been routed to a particular subnetwork on one side - // of the network (left or right) but not the other. The keys are wire indices, and the values - // are subnetwork indices. - let mut partial_routes = [BTreeMap::new(), BTreeMap::new()]; - - // After we route a wire on one side, we find the corresponding wire on the other side and check - // if it still needs to be routed. If so, we add it to partial_routes. - let enqueue_other_side = |partial_routes: &mut [BTreeMap], - witness: &PartitionWitness, - newly_set: &mut [Vec], - side: usize, - this_i: usize, - subnet: bool| { - let other_side = 1 - side; - let other_i = ab_map_by_side(side, this_i); - let other_switch_i = other_i / 2; - - if other_switch_i >= switches[other_side].len() { - // The other wire doesn't go through a switch, so there's no routing to be done. - // This happens in the case of the very last wire. - return; - } - - if witness.contains(switches[other_side][other_switch_i]) - || newly_set[other_side][other_switch_i] - { - // The other switch has already been routed. - return; - } - - let other_i_sibling = 4 * other_switch_i + 1 - other_i; - if let Some(&sibling_subnet) = partial_routes[other_side].get(&other_i_sibling) { - // The other switch's sibling is already pending routing. - assert_ne!(subnet, sibling_subnet); - } else { - let opt_old_subnet = partial_routes[other_side].insert(other_i, subnet); - if let Some(old_subnet) = opt_old_subnet { - assert_eq!(subnet, old_subnet, "Routing conflict (should never happen)"); - } - } - }; - - // See Figure 8 in the AS-Waksman paper. - if even { - enqueue_other_side( - &mut partial_routes, - witness, - &mut newly_set, - 1, - n - 2, - false, - ); - enqueue_other_side(&mut partial_routes, witness, &mut newly_set, 1, n - 1, true); - } else { - enqueue_other_side(&mut partial_routes, witness, &mut newly_set, 0, n - 1, true); - enqueue_other_side(&mut partial_routes, witness, &mut newly_set, 1, n - 1, true); - } - - let route_switch = |partial_routes: &mut [BTreeMap], - witness: &PartitionWitness, - out_buffer: &mut GeneratedValues, - newly_set: &mut [Vec], - side: usize, - switch_index: usize, - swap: bool| { - // First, we actually set the switch configuration. - out_buffer.set_target(switches[side][switch_index], F::from_bool(swap)); - newly_set[side][switch_index] = true; - - // Then, we enqueue the two corresponding wires on the other side of the network, to ensure - // that they get routed in the next step. - let this_i_1 = switch_index * 2; - let this_i_2 = this_i_1 + 1; - enqueue_other_side(partial_routes, witness, newly_set, side, this_i_1, swap); - enqueue_other_side(partial_routes, witness, newly_set, side, this_i_2, !swap); - }; - - // If {a,b}_only_routes is empty, then we can route any switch next. For efficiency, we will - // simply do top-down scans (one on the left side, one on the right side) for switches which - // have not yet been routed. These variables represent the positions of those two scans. - let mut scan_index = [0, 0]; - - // Until both scans complete, we alternate back and worth between the left and right switch - // layers. We process any partially routed wires for that side, or if there aren't any, we route - // the next switch in our scan. - while scan_index[0] < switches[0].len() || scan_index[1] < switches[1].len() { - for side in 0..=1 { - if !partial_routes[side].is_empty() { - for (this_i, subnet) in partial_routes[side].clone().into_iter() { - let this_first_switch_input = this_i % 2 == 0; - let swap = this_first_switch_input == subnet; - let this_switch_i = this_i / 2; - route_switch( - &mut partial_routes, - witness, - out_buffer, - &mut newly_set, - side, - this_switch_i, - swap, - ); - } - partial_routes[side].clear(); - } else { - // We can route any switch next. Continue our scan for pending switches. - while scan_index[side] < switches[side].len() - && (witness.contains(switches[side][scan_index[side]]) - || newly_set[side][scan_index[side]]) - { - scan_index[side] += 1; - } - if scan_index[side] < switches[side].len() { - // Either switch configuration would work; we arbitrarily choose to not swap. - route_switch( - &mut partial_routes, - witness, - out_buffer, - &mut newly_set, - side, - scan_index[side], - false, - ); - scan_index[side] += 1; - } - } - } - } -} - -#[derive(Debug)] -struct PermutationGenerator { - a: Vec>, - b: Vec>, - a_switches: Vec, - b_switches: Vec, - _phantom: PhantomData, -} - -impl SimpleGenerator for PermutationGenerator { - fn dependencies(&self) -> Vec { - self.a.iter().chain(&self.b).flatten().cloned().collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a_values = self - .a - .iter() - .map(|chunk| chunk.iter().map(|wire| witness.get_target(*wire)).collect()) - .collect(); - let b_values = self - .b - .iter() - .map(|chunk| chunk.iter().map(|wire| witness.get_target(*wire)).collect()) - .collect(); - route( - a_values, - b_values, - self.a_switches.clone(), - self.b_switches.clone(), - witness, - out_buffer, - ); - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::field::types::{Field, Sample}; - use plonky2::iop::witness::PartialWitness; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use rand::seq::SliceRandom; - use rand::{thread_rng, Rng}; - - use super::*; - - fn test_permutation_good(size: usize) -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_recursion_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let lst: Vec = (0..size * 2).map(F::from_canonical_usize).collect(); - let a: Vec> = lst[..] - .chunks(2) - .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) - .collect(); - let mut b = a.clone(); - b.shuffle(&mut thread_rng()); - - assert_permutation_circuit(&mut builder, a, b); - - let data = builder.build::(); - let proof = data.prove(pw)?; - - data.verify(proof) - } - - fn test_permutation_duplicates(size: usize) -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_recursion_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let mut rng = thread_rng(); - let lst: Vec = (0..size * 2) - .map(|_| F::from_canonical_usize(rng.gen_range(0..2usize))) - .collect(); - let a: Vec> = lst[..] - .chunks(2) - .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) - .collect(); - - let mut b = a.clone(); - b.shuffle(&mut thread_rng()); - - assert_permutation_circuit(&mut builder, a, b); - - let data = builder.build::(); - let proof = data.prove(pw)?; - - data.verify(proof) - } - - fn test_permutation_bad(size: usize) -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_recursion_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let lst1: Vec = F::rand_vec(size * 2); - let lst2: Vec = F::rand_vec(size * 2); - let a: Vec> = lst1[..] - .chunks(2) - .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) - .collect(); - let b: Vec> = lst2[..] - .chunks(2) - .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])]) - .collect(); - - assert_permutation_circuit(&mut builder, a, b); - - let data = builder.build::(); - data.prove(pw)?; - - Ok(()) - } - - #[test] - fn test_permutations_duplicates() -> Result<()> { - for n in 2..9 { - test_permutation_duplicates(n)?; - } - - Ok(()) - } - - #[test] - fn test_permutations_good() -> Result<()> { - for n in 2..9 { - test_permutation_good(n)?; - } - - Ok(()) - } - - #[test] - #[should_panic] - fn test_permutation_bad_small() { - let size = 2; - - test_permutation_bad(size).unwrap() - } - - #[test] - #[should_panic] - fn test_permutation_bad_medium() { - let size = 6; - - test_permutation_bad(size).unwrap() - } - - #[test] - #[should_panic] - fn test_permutation_bad_large() { - let size = 10; - - test_permutation_bad(size).unwrap() - } -} diff --git a/waksman/src/sorting.rs b/waksman/src/sorting.rs deleted file mode 100644 index 571a066b..00000000 --- a/waksman/src/sorting.rs +++ /dev/null @@ -1,277 +0,0 @@ -use std::marker::PhantomData; - -use itertools::izip; -use plonky2::field::extension::Extendable; -use plonky2::field::types::Field; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; -use plonky2::iop::target::{BoolTarget, Target}; -use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2_util::ceil_div_usize; - -use crate::gates::assert_le::AssertLessThanGate; -use crate::permutation::assert_permutation_circuit; - -pub struct MemoryOp { - is_write: bool, - address: F, - timestamp: F, - value: F, -} - -#[derive(Clone, Debug)] -pub struct MemoryOpTarget { - is_write: BoolTarget, - address: Target, - timestamp: Target, - value: Target, -} - -pub fn assert_permutation_memory_ops_circuit, const D: usize>( - builder: &mut CircuitBuilder, - a: &[MemoryOpTarget], - b: &[MemoryOpTarget], -) { - let a_chunks: Vec> = a - .iter() - .map(|op| vec![op.address, op.timestamp, op.is_write.target, op.value]) - .collect(); - let b_chunks: Vec> = b - .iter() - .map(|op| vec![op.address, op.timestamp, op.is_write.target, op.value]) - .collect(); - - assert_permutation_circuit(builder, a_chunks, b_chunks); -} - -/// Add an AssertLessThanGate to assert that `lhs` is less than `rhs`, where their values are at most `bits` bits. -pub fn assert_le_circuit, const D: usize>( - builder: &mut CircuitBuilder, - lhs: Target, - rhs: Target, - bits: usize, - num_chunks: usize, -) { - let gate = AssertLessThanGate::new(bits, num_chunks); - let row = builder.add_gate(gate.clone(), vec![]); - - builder.connect(Target::wire(row, gate.wire_first_input()), lhs); - builder.connect(Target::wire(row, gate.wire_second_input()), rhs); -} - -/// Sort memory operations by address value, then by timestamp value. -/// This is done by combining address and timestamp into one field element (using their given bit lengths). -pub fn sort_memory_ops_circuit, const D: usize>( - builder: &mut CircuitBuilder, - ops: &[MemoryOpTarget], - address_bits: usize, - timestamp_bits: usize, -) -> Vec { - let n = ops.len(); - - let combined_bits = address_bits + timestamp_bits; - let chunk_bits = 3; - let num_chunks = ceil_div_usize(combined_bits, chunk_bits); - - // This is safe because `assert_permutation` will force these targets (in the output list) to match the boolean values from the input list. - let is_write_targets: Vec<_> = builder - .add_virtual_targets(n) - .iter() - .map(|&t| BoolTarget::new_unsafe(t)) - .collect(); - - let address_targets = builder.add_virtual_targets(n); - let timestamp_targets = builder.add_virtual_targets(n); - let value_targets = builder.add_virtual_targets(n); - - let output_targets: Vec<_> = izip!( - is_write_targets, - address_targets, - timestamp_targets, - value_targets - ) - .map(|(i, a, t, v)| MemoryOpTarget { - is_write: i, - address: a, - timestamp: t, - value: v, - }) - .collect(); - - let two_n = builder.constant(F::from_canonical_usize(1 << timestamp_bits)); - let address_timestamp_combined: Vec<_> = output_targets - .iter() - .map(|op| builder.mul_add(op.address, two_n, op.timestamp)) - .collect(); - - for i in 1..n { - assert_le_circuit( - builder, - address_timestamp_combined[i - 1], - address_timestamp_combined[i], - combined_bits, - num_chunks, - ); - } - - assert_permutation_memory_ops_circuit(builder, ops, &output_targets); - - builder.add_simple_generator(MemoryOpSortGenerator:: { - input_ops: ops.to_vec(), - output_ops: output_targets.clone(), - _phantom: PhantomData, - }); - - output_targets -} - -#[derive(Debug)] -struct MemoryOpSortGenerator, const D: usize> { - input_ops: Vec, - output_ops: Vec, - _phantom: PhantomData, -} - -impl, const D: usize> SimpleGenerator - for MemoryOpSortGenerator -{ - fn dependencies(&self) -> Vec { - self.input_ops - .iter() - .flat_map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value]) - .collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let n = self.input_ops.len(); - debug_assert!(self.output_ops.len() == n); - - let mut ops: Vec<_> = self - .input_ops - .iter() - .map(|op| { - let is_write = witness.get_bool_target(op.is_write); - let address = witness.get_target(op.address); - let timestamp = witness.get_target(op.timestamp); - let value = witness.get_target(op.value); - MemoryOp { - is_write, - address, - timestamp, - value, - } - }) - .collect(); - - ops.sort_unstable_by_key(|op| { - ( - op.address.to_canonical_u64(), - op.timestamp.to_canonical_u64(), - ) - }); - - for (op, out_op) in ops.iter().zip(&self.output_ops) { - out_buffer.set_target(out_op.is_write.target, F::from_bool(op.is_write)); - out_buffer.set_target(out_op.address, op.address); - out_buffer.set_target(out_op.timestamp, op.timestamp); - out_buffer.set_target(out_op.value, op.value); - } - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2::field::types::{Field, PrimeField64, Sample}; - use plonky2::iop::witness::PartialWitness; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use rand::{thread_rng, Rng}; - - use super::*; - - fn test_sorting(size: usize, address_bits: usize, timestamp_bits: usize) -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_recursion_config(); - - let mut pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let mut rng = thread_rng(); - let is_write_vals: Vec<_> = (0..size).map(|_| rng.gen_range(0..2) != 0).collect(); - let address_vals: Vec<_> = (0..size) - .map(|_| F::from_canonical_u64(rng.gen_range(0..1 << address_bits as u64))) - .collect(); - let timestamp_vals: Vec<_> = (0..size) - .map(|_| F::from_canonical_u64(rng.gen_range(0..1 << timestamp_bits as u64))) - .collect(); - let value_vals: Vec<_> = (0..size).map(|_| F::rand()).collect(); - - let input_ops: Vec = izip!( - is_write_vals.clone(), - address_vals.clone(), - timestamp_vals.clone(), - value_vals.clone() - ) - .map(|(is_write, address, timestamp, value)| MemoryOpTarget { - is_write: builder.constant_bool(is_write), - address: builder.constant(address), - timestamp: builder.constant(timestamp), - value: builder.constant(value), - }) - .collect(); - - let combined_vals_u64: Vec<_> = timestamp_vals - .iter() - .zip(&address_vals) - .map(|(&t, &a)| (a.to_canonical_u64() << timestamp_bits as u64) + t.to_canonical_u64()) - .collect(); - let mut input_ops_and_keys: Vec<_> = - izip!(is_write_vals, address_vals, timestamp_vals, value_vals) - .zip(combined_vals_u64) - .collect::>(); - input_ops_and_keys.sort_by_key(|(_, val)| *val); - let input_ops_sorted: Vec<_> = input_ops_and_keys.iter().map(|(x, _)| x).collect(); - - let output_ops = sort_memory_ops_circuit( - &mut builder, - input_ops.as_slice(), - address_bits, - timestamp_bits, - ); - - for i in 0..size { - pw.set_bool_target(output_ops[i].is_write, input_ops_sorted[i].0); - pw.set_target(output_ops[i].address, input_ops_sorted[i].1); - pw.set_target(output_ops[i].timestamp, input_ops_sorted[i].2); - pw.set_target(output_ops[i].value, input_ops_sorted[i].3); - } - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - data.verify(proof) - } - - #[test] - fn test_sorting_small() -> Result<()> { - let size = 5; - let address_bits = 20; - let timestamp_bits = 20; - - test_sorting(size, address_bits, timestamp_bits) - } - - #[test] - fn test_sorting_large() -> Result<()> { - let size = 20; - let address_bits = 20; - let timestamp_bits = 20; - - test_sorting(size, address_bits, timestamp_bits) - } -}