intermediate wires

This commit is contained in:
Nicholas Ward 2021-09-16 11:16:32 -07:00
parent 7abf48cd07
commit 8681cdec54

View File

@ -65,6 +65,11 @@ impl<F: RichField + Extendable<D>, const D: usize> ComparisonGate<F, D> {
debug_assert!(chunk < self.num_chunks);
3 + 3 * self.num_chunks + chunk
}
pub fn wire_intermediate_value(&self, chunk: usize) -> usize {
debug_assert!(chunk < self.num_chunks);
3 + 4 * self.num_chunks + chunk
}
}
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
@ -122,8 +127,10 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate
constraints.push(chunks_equal * difference);
// Update `most_significant_diff_so_far`.
most_significant_diff_so_far = chunks_equal * most_significant_diff_so_far
+ (F::Extension::ONE - chunks_equal) * difference;
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far);
most_significant_diff_so_far =
intermediate_value + (F::Extension::ONE - chunks_equal) * difference;
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
@ -188,8 +195,10 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate
constraints.push(chunks_equal * difference);
// Update `most_significant_diff_so_far`.
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far);
most_significant_diff_so_far =
chunks_equal * most_significant_diff_so_far + (F::ONE - chunks_equal) * difference;
intermediate_value + (F::ONE - chunks_equal) * difference;
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
@ -262,10 +271,13 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate
constraints.push(builder.mul_extension(chunks_equal, difference));
// Update `most_significant_diff_so_far`.
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
let old_diff = builder.mul_extension(chunks_equal, most_significant_diff_so_far);
constraints.push(builder.sub_extension(intermediate_value, old_diff));
let not_equal = builder.sub_extension(one, chunks_equal);
let new_diff = builder.mul_extension(not_equal, difference);
most_significant_diff_so_far = builder.add_extension(old_diff, new_diff);
most_significant_diff_so_far = builder.add_extension(intermediate_value, new_diff);
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
@ -297,7 +309,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate
}
fn num_wires(&self) -> usize {
self.wire_chunks_equal(self.num_chunks - 1) + 1
self.wire_intermediate_value(self.num_chunks - 1) + 1
}
fn num_constants(&self) -> usize {
@ -305,11 +317,11 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate
}
fn degree(&self) -> usize {
(self.num_chunks + 1).max(1 << self.chunk_bits())
1 << self.chunk_bits()
}
fn num_constraints(&self) -> usize {
4 + 4 * self.num_chunks
4 + 5 * self.num_chunks
}
}
@ -372,15 +384,17 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
.map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) })
.collect();
let mut diff_index = 0;
let mut most_significant_diff_so_far = F::ZERO;
let mut intermediate_values = Vec::new();
for i in 1..self.gate.num_chunks {
if first_input_chunks[i] != second_input_chunks[i] {
diff_index = i;
most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i];
intermediate_values.push(F::ZERO);
} else {
intermediate_values.push(most_significant_diff_so_far);
}
}
let most_significant_diff =
second_input_chunks[diff_index] - first_input_chunks[diff_index];
let most_significant_diff = most_significant_diff_so_far;
out_buffer.set_wire(
local_wire(self.gate.wire_most_significant_diff()),
@ -400,6 +414,10 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
equality_dummies[i],
);
out_buffer.set_wire(local_wire(self.gate.wire_chunks_equal(i)), chunks_equal[i]);
out_buffer.set_wire(
local_wire(self.gate.wire_intermediate_value(i)),
intermediate_values[i],
);
}
}
}
@ -444,6 +462,8 @@ mod tests {
assert_eq!(gate.wire_equality_dummy(4), 17);
assert_eq!(gate.wire_chunks_equal(0), 18);
assert_eq!(gate.wire_chunks_equal(4), 22);
assert_eq!(gate.wire_intermediate_value(0), 23);
assert_eq!(gate.wire_intermediate_value(4), 27);
}
#[test]
@ -504,15 +524,17 @@ mod tests {
.map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) })
.collect();
let mut diff_index = 0;
for i in 1..num_chunks {
let mut most_significant_diff_so_far = F::ZERO;
let mut intermediate_values = Vec::new();
for i in 0..num_chunks {
if first_input_chunks[i] != second_input_chunks[i] {
diff_index = i;
most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i];
intermediate_values.push(F::ZERO);
} else {
intermediate_values.push(most_significant_diff_so_far);
}
}
let most_significant_diff =
second_input_chunks[diff_index] - first_input_chunks[diff_index];
let most_significant_diff = most_significant_diff_so_far;
v.push(first_input);
v.push(second_input);
@ -521,6 +543,7 @@ mod tests {
v.append(&mut second_input_chunks);
v.append(&mut equality_dummies);
v.append(&mut chunks_equal);
v.append(&mut intermediate_values);
v.iter().map(|&x| x.into()).collect::<Vec<_>>()
};