From 1a383b6260fd7b68aaade55af3a9edbd70e7f5f5 Mon Sep 17 00:00:00 2001 From: oskarth Date: Thu, 13 Jan 2022 19:30:00 +0800 Subject: [PATCH] Ensure Circom 1 tests pass with experimental Circom 2 support (#18) * All tests pass under circom-2 feature flag - Check for version in WASM, default to version 1 - Include Circom1 when Circom 2 feature flag is enabled Currently a lot of code duplication. Once Circom-2 is more stable and proven to work in the wild, feature flag can be removed. * Separate Circom 1 and Circom2 witness calculation * Cleanup * WitnessCalculator helpers for Circom 1 and 2 Also make helper fn private * Move comment * Fix expression return * cargo fmt * Add cargo test circom-2 to ci --- .github/workflows/ci.yml | 5 + Cargo.lock | 47 +++++++ src/witness/circom.rs | 16 ++- src/witness/mod.rs | 1 - src/witness/witness_calculator.rs | 208 ++++++++++++++++++++---------- 5 files changed, 200 insertions(+), 77 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 43d2848..1edc363 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -44,6 +44,11 @@ jobs: export PATH=$HOME/bin:$PATH cargo test + - name: cargo test circom 2 feature flag + run: | + export PATH=$HOME/bin:$PATH + cargo test --features circom-2 + lint: runs-on: ubuntu-latest steps: diff --git a/Cargo.lock b/Cargo.lock index 562eec3..dc546f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -97,6 +97,7 @@ dependencies = [ "fnv", "hex", "hex-literal", + "num", "num-bigint", "num-traits", "serde_json", @@ -1950,6 +1951,20 @@ dependencies = [ "winapi", ] +[[package]] +name = "num" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + [[package]] name = "num-bigint" version = "0.4.0" @@ -1962,6 +1977,15 @@ dependencies = [ "rand", ] +[[package]] +name = "num-complex" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" +dependencies = [ + "num-traits", +] + [[package]] name = "num-integer" version = "0.1.44" @@ -1972,6 +1996,29 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-iter" +version = "0.1.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2021c8337a54d21aca0d59a92577a029af9431cb59b909b03252b9c164fad59" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a" +dependencies = [ + "autocfg", + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.14" diff --git a/src/witness/circom.rs b/src/witness/circom.rs index 4b7a4e3..f72563e 100644 --- a/src/witness/circom.rs +++ b/src/witness/circom.rs @@ -19,6 +19,8 @@ pub trait CircomBase { ) -> Result<()>; fn set_signal(&self, c_idx: i32, component: i32, signal: i32, p_val: i32) -> Result<()>; fn get_i32(&self, name: &str) -> Result; + // Only exists natively in Circom2, hardcoded for Circom + fn get_version(&self) -> Result; } pub trait Circom { @@ -27,7 +29,6 @@ pub trait Circom { } pub trait Circom2 { - fn get_version(&self) -> Result; fn get_field_num_len32(&self) -> Result; fn get_raw_prime(&self) -> Result<()>; fn read_shared_rw_memory(&self, i: i32) -> Result; @@ -37,7 +38,6 @@ pub trait Circom2 { fn get_witness_size(&self) -> Result; } -#[cfg(not(feature = "circom-2"))] impl Circom for Wasm { fn get_fr_len(&self) -> Result { self.get_i32("getFrLen") @@ -50,10 +50,6 @@ impl Circom for Wasm { #[cfg(feature = "circom-2")] impl Circom2 for Wasm { - fn get_version(&self) -> Result { - self.get_i32("getVersion") - } - fn get_field_num_len32(&self) -> Result { self.get_i32("getFieldNumLen32") } @@ -142,6 +138,14 @@ impl CircomBase for Wasm { Ok(()) } + // Default to version 1 if it isn't explicitly defined + fn get_version(&self) -> Result { + match self.0.exports.get_function("getVersion") { + Ok(func) => Ok(func.call(&[])?[0].unwrap_i32()), + Err(_) => Ok(1), + } + } + fn get_i32(&self, name: &str) -> Result { let func = self.func(name); let result = func.call(&[])?; diff --git a/src/witness/mod.rs b/src/witness/mod.rs index 2a2cb37..51708aa 100644 --- a/src/witness/mod.rs +++ b/src/witness/mod.rs @@ -10,7 +10,6 @@ pub(super) use circom::{CircomBase, Wasm}; #[cfg(feature = "circom-2")] pub(super) use circom::Circom2; -#[cfg(not(feature = "circom-2"))] pub(super) use circom::Circom; use fnv::FnvHasher; diff --git a/src/witness/witness_calculator.rs b/src/witness/witness_calculator.rs index 032ffb7..c5cad64 100644 --- a/src/witness/witness_calculator.rs +++ b/src/witness/witness_calculator.rs @@ -11,7 +11,6 @@ use num::ToPrimitive; #[cfg(feature = "circom-2")] use super::Circom2; -#[cfg(not(feature = "circom-2"))] use super::Circom; #[derive(Clone, Debug)] @@ -19,6 +18,7 @@ pub struct WitnessCalculator { pub instance: Wasm, pub memory: SafeMemory, pub n64: i32, + pub circom_version: i32, } // Error type to signal end of execution. @@ -77,36 +77,73 @@ impl WitnessCalculator { }; let instance = Wasm::new(Instance::new(&module, &import_object)?); - cfg_if::cfg_if! { - if #[cfg(feature = "circom-2")] { - //let version = instance.get_version()?; - let n32 = instance.get_field_num_len32()?; - let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero()); - instance.get_raw_prime()?; - let mut arr = vec![0; n32 as usize]; - for i in 0..n32 { - let res = instance.read_shared_rw_memory(i)?; - arr[(n32 as usize) - (i as usize) - 1] = res; - } - let prime = from_array32(arr); - } else { - // Fallback to Circom 1 behavior - //version = 1; - let n32 = (instance.get_fr_len()? >> 2) - 2; - let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero()); - let ptr = instance.get_ptr_raw_prime()?; - let prime = safe_memory.read_big(ptr as usize, n32 as usize)?; - } + let version; + + match instance.get_version() { + Ok(v) => version = v, + Err(_) => version = 1, } - let n64 = ((prime.bits() - 1) / 64 + 1) as i32; - safe_memory.prime = prime; + // Circom 2 feature flag with version 2 + #[cfg(feature = "circom-2")] + fn new_circom2(instance: Wasm, memory: Memory, version: i32) -> Result { + let n32 = instance.get_field_num_len32()?; + let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero()); + instance.get_raw_prime()?; + let mut arr = vec![0; n32 as usize]; + for i in 0..n32 { + let res = instance.read_shared_rw_memory(i)?; + arr[(n32 as usize) - (i as usize) - 1] = res; + } + let prime = from_array32(arr); - Ok(WitnessCalculator { - instance, - memory: safe_memory, - n64, - }) + let n64 = ((prime.bits() - 1) / 64 + 1) as i32; + safe_memory.prime = prime; + + Ok(WitnessCalculator { + instance, + memory: safe_memory, + n64, + circom_version: version, + }) + } + + fn new_circom1(instance: Wasm, memory: Memory, version: i32) -> Result { + // Fallback to Circom 1 behavior + let n32 = (instance.get_fr_len()? >> 2) - 2; + let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero()); + let ptr = instance.get_ptr_raw_prime()?; + let prime = safe_memory.read_big(ptr as usize, n32 as usize)?; + + let n64 = ((prime.bits() - 1) / 64 + 1) as i32; + safe_memory.prime = prime; + + Ok(WitnessCalculator { + instance, + memory: safe_memory, + n64, + circom_version: version, + }) + } + + // Three possibilities: + // a) Circom 2 feature flag enabled, WASM runtime version 2 + // b) Circom 2 feature flag enabled, WASM runtime version 1 + // c) Circom 1 default behavior + // + // Once Circom 2 support is more stable, feature flag can be removed + + cfg_if::cfg_if! { + if #[cfg(feature = "circom-2")] { + match version { + 2 => new_circom2(instance, memory, version), + 1 => new_circom1(instance, memory, version), + _ => panic!("Unknown Circom version") + } + } else { + new_circom1(instance, memory, version) + } + } } pub fn calculate_witness)>>( @@ -118,68 +155,99 @@ impl WitnessCalculator { cfg_if::cfg_if! { if #[cfg(feature = "circom-2")] { - let n32 = self.instance.get_field_num_len32()?; + match self.circom_version { + 2 => self.calculate_witness_circom2(inputs, sanity_check), + 1 => self.calculate_witness_circom1(inputs, sanity_check), + _ => panic!("Unknown Circom version") + } } else { - let old_mem_free_pos = self.memory.free_pos(); - let p_sig_offset = self.memory.alloc_u32(); - let p_fr = self.memory.alloc_fr(); + self.calculate_witness_circom1(inputs, sanity_check) } } + } + + // Circom 1 default behavior + fn calculate_witness_circom1)>>( + &mut self, + inputs: I, + sanity_check: bool, + ) -> Result> { + self.instance.init(sanity_check)?; + + let old_mem_free_pos = self.memory.free_pos(); + let p_sig_offset = self.memory.alloc_u32(); + let p_fr = self.memory.alloc_fr(); // allocate the inputs for (name, values) in inputs.into_iter() { let (msb, lsb) = fnv(&name); - cfg_if::cfg_if! { - if #[cfg(feature = "circom-2")] { - for (i, value) in values.into_iter().enumerate() { - let f_arr = to_array32(&value, n32 as usize); - for j in 0..n32 { - self.instance.write_shared_rw_memory(j as i32, f_arr[(n32 as usize) - 1 - (j as usize)])?; - } - self.instance.set_input_signal(msb as i32, lsb as i32, i as i32)?; - } - } else { - self.instance - .get_signal_offset32(p_sig_offset, 0, msb, lsb)?; + self.instance + .get_signal_offset32(p_sig_offset, 0, msb, lsb)?; - let sig_offset = self.memory.read_u32(p_sig_offset as usize) as usize; + let sig_offset = self.memory.read_u32(p_sig_offset as usize) as usize; - for (i, value) in values.into_iter().enumerate() { - self.memory.write_fr(p_fr as usize, &value)?; - self.instance - .set_signal(0, 0, (sig_offset + i) as i32, p_fr as i32)?; - } - } + for (i, value) in values.into_iter().enumerate() { + self.memory.write_fr(p_fr as usize, &value)?; + self.instance + .set_signal(0, 0, (sig_offset + i) as i32, p_fr as i32)?; } } let mut w = Vec::new(); - cfg_if::cfg_if! { - if #[cfg(feature = "circom-2")] { - let witness_size = self.instance.get_witness_size()?; - for i in 0..witness_size { - self.instance.get_witness(i)?; - let mut arr = vec![0; n32 as usize]; - for j in 0..n32 { - arr[(n32 as usize) - 1- (j as usize)] = self.instance.read_shared_rw_memory(j)?; - } - w.push(from_array32(arr)); - } + let n_vars = self.instance.get_n_vars()?; + for i in 0..n_vars { + let ptr = self.instance.get_ptr_witness(i)? as usize; + let el = self.memory.read_fr(ptr)?; + w.push(el); + } - } else { - let n_vars = self.instance.get_n_vars()?; - for i in 0..n_vars { - let ptr = self.instance.get_ptr_witness(i)? as usize; - let el = self.memory.read_fr(ptr)?; - w.push(el); - } + self.memory.set_free_pos(old_mem_free_pos); - self.memory.set_free_pos(old_mem_free_pos); + Ok(w) + } + + // Circom 2 feature flag with version 2 + #[cfg(feature = "circom-2")] + fn calculate_witness_circom2)>>( + &mut self, + inputs: I, + sanity_check: bool, + ) -> Result> { + self.instance.init(sanity_check)?; + + let n32 = self.instance.get_field_num_len32()?; + + // allocate the inputs + for (name, values) in inputs.into_iter() { + let (msb, lsb) = fnv(&name); + + for (i, value) in values.into_iter().enumerate() { + let f_arr = to_array32(&value, n32 as usize); + for j in 0..n32 { + self.instance.write_shared_rw_memory( + j as i32, + f_arr[(n32 as usize) - 1 - (j as usize)], + )?; + } + self.instance + .set_input_signal(msb as i32, lsb as i32, i as i32)?; } } + let mut w = Vec::new(); + + let witness_size = self.instance.get_witness_size()?; + for i in 0..witness_size { + self.instance.get_witness(i)?; + let mut arr = vec![0; n32 as usize]; + for j in 0..n32 { + arr[(n32 as usize) - 1 - (j as usize)] = self.instance.read_shared_rw_memory(j)?; + } + w.push(from_array32(arr)); + } + Ok(w) }