fixed is_equal

This commit is contained in:
Nicholas Ward 2022-02-11 12:54:41 -08:00
parent 3787f3be22
commit ad1aa4ae10

View File

@ -327,16 +327,23 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
pub fn is_equal(&mut self, x: Target, y: Target) -> BoolTarget {
let b = self.add_virtual_bool_target();
self.add_simple_generator(EqualityGenerator { x, y, b });
let zero = self.zero();
let equal = self.add_virtual_bool_target();
let not_equal = self.not(equal);
let inv = self.add_virtual_target();
self.add_simple_generator(EqualityGenerator { x, y, equal, inv });
let diff = self.sub(x, y);
let result = self.mul(b.target, diff);
let not_equal_check = self.mul(equal.target, diff);
let zero = self.zero();
self.connect(zero, result);
let diff_normalized = self.mul(diff, inv);
let equal_check = self.sub(diff_normalized, not_equal.target);
b
self.connect(not_equal_check, zero);
self.connect(equal_check, zero);
equal
}
}
@ -344,7 +351,8 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
struct EqualityGenerator {
x: Target,
y: Target,
b: BoolTarget,
equal: BoolTarget,
inv: Target,
}
impl<F: RichField> SimpleGenerator<F> for EqualityGenerator {
@ -356,7 +364,14 @@ impl<F: RichField> SimpleGenerator<F> for EqualityGenerator {
let x = witness.get_target(self.x);
let y = witness.get_target(self.y);
out_buffer.set_bool_target(self.b, x == y);
let inv = if x != y {
(x - y).inverse()
} else {
F::ZERO
};
out_buffer.set_bool_target(self.equal, x == y);
out_buffer.set_target(self.inv, inv);
}
}