diff --git a/plonky2/src/gadgets/arithmetic.rs b/plonky2/src/gadgets/arithmetic.rs index 11fc57bf..3c43d97a 100644 --- a/plonky2/src/gadgets/arithmetic.rs +++ b/plonky2/src/gadgets/arithmetic.rs @@ -6,7 +6,9 @@ use plonky2_field::field_types::Field64; use crate::gates::arithmetic_base::ArithmeticGate; use crate::gates::exponentiation::ExponentiationGate; use crate::hash::hash_types::RichField; +use crate::iop::generator::{SimpleGenerator, GeneratedValues}; use crate::iop::target::{BoolTarget, Target}; +use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { @@ -323,6 +325,44 @@ impl, const D: usize> CircuitBuilder { let res = self.sub(one, b.target); BoolTarget::new_unsafe(res) } + + pub fn is_equal(&mut self, x: Target, y: Target) -> BoolTarget { + let b = self.add_virtual_bool_target(); + self.add_simple_generator(EqualityGenerator { + x, + y, + b, + }); + + let diff = self.sub(x, y); + let result = self.mul(b.target, diff); + + let zero = self.zero(); + self.connect(zero, result); + + b + } +} + +#[derive(Debug)] +struct EqualityGenerator { + x: Target, + y: Target, + b: BoolTarget, +} + +impl SimpleGenerator for EqualityGenerator +{ + fn dependencies(&self) -> Vec { + vec![self.x, self.y] + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let x = witness.get_target(self.x); + let y = witness.get_target(self.y); + + out_buffer.set_bool_target(self.b, x == y); + } } /// Represents a base arithmetic operation in the circuit. Used to memoize results. diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index 83309196..d2a298a8 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -124,14 +124,19 @@ impl, const D: usize> CircuitBuilder { p2: &AffinePointTarget, b: BoolTarget, ) -> AffinePointTarget { - let to_add_x = self.mul_nonnative_by_bool(&p2.x, b); - let to_add_y = self.mul_nonnative_by_bool(&p2.y, b); - let sum_x = self.add_nonnative(&p1.x, &to_add_x); - let sum_y = self.add_nonnative(&p1.y, &to_add_y); + let not_b = self.not(b); + let sum = self.curve_add(p1, p2); + let x_if_true = self.mul_nonnative_by_bool(&sum.x, b); + let y_if_true = self.mul_nonnative_by_bool(&sum.y, b); + let x_if_false = self.mul_nonnative_by_bool(&p1.x, not_b); + let y_if_false = self.mul_nonnative_by_bool(&p1.y, not_b); + + let x = self.add_nonnative(&x_if_true, &x_if_false); + let y = self.add_nonnative(&y_if_true, &y_if_false); AffinePointTarget { - x: sum_x, - y: sum_y, + x, + y, } }