diff --git a/Cargo.toml b/Cargo.toml index 59b97ab..2cc83dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,12 +32,13 @@ thiserror = "1.0.26" color-eyre = "0.5" criterion = "0.3.4" +cfg-if = "1.0" + [dev-dependencies] hex-literal = "0.2.1" tokio = { version = "1.7.1", features = ["macros"] } serde_json = "1.0.64" ethers = { git = "https://github.com/gakonst/ethers-rs", features = ["abigen"] } -cfg-if = "1.0" [[bench]] name = "groth16" @@ -45,3 +46,4 @@ harness = false [features] bench-complex-all = [] +circom-2 = [] diff --git a/src/witness/circom.rs b/src/witness/circom.rs index d9e9dbb..ae31b72 100644 --- a/src/witness/circom.rs +++ b/src/witness/circom.rs @@ -4,41 +4,92 @@ use wasmer::{Function, Instance, Value}; #[derive(Clone, Debug)] pub struct Wasm(Instance); -impl Wasm { - pub fn new(instance: Instance) -> Self { - Self(instance) +pub trait CircomBase { + fn init(&self, sanity_check: bool) -> Result<()>; + fn func(&self, name: &str) -> &Function; + fn get_ptr_witness_buffer(&self) -> Result; + fn get_ptr_witness(&self, w: i32) -> Result; + fn get_n_vars(&self) -> Result; + fn get_signal_offset32( + &self, + p_sig_offset: u32, + component: u32, + hash_msb: u32, + hash_lsb: u32, + ) -> Result<()>; + fn set_signal(&self, c_idx: i32, component: i32, signal: i32, p_val: i32) -> Result<()>; + fn get_i32(&self, name: &str) -> Result; +} + +pub trait Circom { + fn get_fr_len(&self) -> Result; + fn get_ptr_raw_prime(&self) -> Result; +} + +pub trait Circom2 { + fn get_version(&self) -> Result; + fn get_field_num_len32(&self) -> Result; + fn get_raw_prime(&self) -> Result<()>; + fn read_shared_rw_memory(&self, i: i32) -> Result; +} + +#[cfg(not(feature = "circom-2"))] +impl Circom for Wasm { + fn get_fr_len(&self) -> Result { + self.get_i32("getFrLen") } - pub fn init(&self, sanity_check: bool) -> Result<()> { + fn get_ptr_raw_prime(&self) -> Result { + self.get_i32("getPRawPrime") + } +} + +#[cfg(feature = "circom-2")] +impl Circom2 for Wasm { + fn get_version(&self) -> Result { + self.get_i32("getVersion") + } + + fn get_field_num_len32(&self) -> Result { + self.get_i32("getFieldNumLen32") + } + + fn get_raw_prime(&self) -> Result<()> { + let func = self.func("getRawPrime"); + let _result = func.call(&[])?; + Ok(()) + } + + fn read_shared_rw_memory(&self, i: i32) -> Result { + let func = self.func("readSharedRWMemory"); + let result = func.call(&[i.into()])?; + Ok(result[0].unwrap_i32()) + } +} + +impl CircomBase for Wasm { + fn init(&self, sanity_check: bool) -> Result<()> { let func = self.func("init"); func.call(&[Value::I32(sanity_check as i32)])?; Ok(()) } - pub fn get_fr_len(&self) -> Result { - self.get_i32("getFrLen") - } - - pub fn get_ptr_raw_prime(&self) -> Result { - self.get_i32("getPRawPrime") - } - - pub fn get_n_vars(&self) -> Result { - self.get_i32("getNVars") - } - - pub fn get_ptr_witness_buffer(&self) -> Result { + fn get_ptr_witness_buffer(&self) -> Result { self.get_i32("getWitnessBuffer") } - pub fn get_ptr_witness(&self, w: i32) -> Result { + fn get_ptr_witness(&self, w: i32) -> Result { let func = self.func("getPWitness"); let res = func.call(&[w.into()])?; Ok(res[0].unwrap_i32()) } - pub fn get_signal_offset32( + fn get_n_vars(&self) -> Result { + self.get_i32("getNVars") + } + + fn get_signal_offset32( &self, p_sig_offset: u32, component: u32, @@ -56,7 +107,7 @@ impl Wasm { Ok(()) } - pub fn set_signal(&self, c_idx: i32, component: i32, signal: i32, p_val: i32) -> Result<()> { + fn set_signal(&self, c_idx: i32, component: i32, signal: i32, p_val: i32) -> Result<()> { let func = self.func("setSignal"); func.call(&[c_idx.into(), component.into(), signal.into(), p_val.into()])?; @@ -76,3 +127,9 @@ impl Wasm { .unwrap_or_else(|_| panic!("function {} not found", name)) } } + +impl Wasm { + pub fn new(instance: Instance) -> Self { + Self(instance) + } +} diff --git a/src/witness/mod.rs b/src/witness/mod.rs index 6112a06..2a2cb37 100644 --- a/src/witness/mod.rs +++ b/src/witness/mod.rs @@ -5,7 +5,13 @@ mod memory; pub(super) use memory::SafeMemory; mod circom; -pub(super) use circom::Wasm; +pub(super) use circom::{CircomBase, Wasm}; + +#[cfg(feature = "circom-2")] +pub(super) use circom::Circom2; + +#[cfg(not(feature = "circom-2"))] +pub(super) use circom::Circom; use fnv::FnvHasher; use std::hash::Hasher; diff --git a/src/witness/witness_calculator.rs b/src/witness/witness_calculator.rs index 46fb6f2..54904da 100644 --- a/src/witness/witness_calculator.rs +++ b/src/witness/witness_calculator.rs @@ -1,10 +1,15 @@ +use super::{fnv, CircomBase, SafeMemory, Wasm}; use color_eyre::Result; use num_bigint::BigInt; use num_traits::Zero; use std::cell::Cell; use wasmer::{imports, Function, Instance, Memory, MemoryType, Module, RuntimeError, Store}; -use super::{fnv, SafeMemory, Wasm}; +#[cfg(feature = "circom-2")] +use super::Circom2; + +#[cfg(not(feature = "circom-2"))] +use super::Circom; #[derive(Clone, Debug)] pub struct WitnessCalculator { @@ -19,6 +24,16 @@ pub struct WitnessCalculator { #[error("{0}")] struct ExitCode(u32); +#[cfg(feature = "circom-2")] +fn from_array32(arr: Vec) -> BigInt { + let mut res = BigInt::zero(); + let radix = BigInt::from(0x100000000u64); + for &val in arr.iter() { + res = res * &radix + BigInt::from(val); + } + res +} + impl WitnessCalculator { pub fn new(path: impl AsRef) -> Result { let store = Store::default(); @@ -38,22 +53,44 @@ impl WitnessCalculator { "logFinishComponent" => runtime::log_component(&store), "logStartComponent" => runtime::log_component(&store), "log" => runtime::log_component(&store), + "exceptionHandler" => runtime::exception_handler(&store), + "showSharedRWMemory" => runtime::show_memory(&store), } }; let instance = Wasm::new(Instance::new(&module, &import_object)?); - let n32 = (instance.get_fr_len()? >> 2) - 2; + let n32; + let prime: BigInt; + let mut safe_memory: SafeMemory; - let mut memory = SafeMemory::new(memory, n32 as usize, BigInt::zero()); + cfg_if::cfg_if! { + if #[cfg(feature = "circom-2")] { + //let version = instance.get_version()?; + n32 = instance.get_field_num_len32()?; + safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero()); + let _res = instance.get_raw_prime()?; + let mut arr = vec![0; n32 as usize]; + for i in 0..n32 { + let res = instance.read_shared_rw_memory(i)?; + arr[(n32 as usize) - (i as usize) - 1] = res; + } + prime = from_array32(arr); + } else { + // Fallback to Circom 1 behavior + //version = 1; + n32 = (instance.get_fr_len()? >> 2) - 2; + safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero()); + let ptr = instance.get_ptr_raw_prime()?; + prime = safe_memory.read_big(ptr as usize, n32 as usize)?; + } + } - let ptr = instance.get_ptr_raw_prime()?; - let prime = memory.read_big(ptr as usize, n32 as usize)?; let n64 = ((prime.bits() - 1) / 64 + 1) as i32; - memory.prime = prime; + safe_memory.prime = prime; Ok(WitnessCalculator { instance, - memory, + memory: safe_memory, n64, }) } @@ -162,6 +199,20 @@ mod runtime { Function::new_native(store, func) } + // Circom 2.0 + pub fn exception_handler(store: &Store) -> Function { + #[allow(unused)] + fn func(a: i32) {} + Function::new_native(store, func) + } + + // Circom 2.0 + pub fn show_memory(store: &Store) -> Function { + #[allow(unused)] + fn func() {} + Function::new_native(store, func) + } + pub fn log_signal(store: &Store) -> Function { #[allow(unused)] fn func(a: i32, b: i32) {} diff --git a/test-vectors/circom2_multiplier2.r1cs b/test-vectors/circom2_multiplier2.r1cs new file mode 100644 index 0000000..e61b9b0 Binary files /dev/null and b/test-vectors/circom2_multiplier2.r1cs differ diff --git a/test-vectors/circom2_multiplier2.wasm b/test-vectors/circom2_multiplier2.wasm new file mode 100644 index 0000000..7ea487c Binary files /dev/null and b/test-vectors/circom2_multiplier2.wasm differ diff --git a/tests/groth16.rs b/tests/groth16.rs index c7eadd7..f4873ef 100644 --- a/tests/groth16.rs +++ b/tests/groth16.rs @@ -58,3 +58,35 @@ fn groth16_proof_wrong_input() { builder.build().unwrap_err(); } + +#[test] +#[cfg(feature = "circom-2")] +fn groth16_proof_circom2() -> Result<()> { + let cfg = CircomConfig::::new( + "./test-vectors/circom2_multiplier2.wasm", + "./test-vectors/circom2_multiplier2.r1cs", + )?; + let mut builder = CircomBuilder::new(cfg); + builder.push_input("a", 3); + builder.push_input("b", 11); + + // create an empty instance for setting it up + let circom = builder.setup(); + + let mut rng = thread_rng(); + let params = generate_random_parameters::(circom, &mut rng)?; + + let circom = builder.build()?; + + let inputs = circom.get_public_inputs().unwrap(); + + let proof = prove(circom, ¶ms, &mut rng)?; + + let pvk = prepare_verifying_key(¶ms.vk); + + let verified = verify_proof(&pvk, &proof, &inputs)?; + + assert!(verified); + + Ok(()) +}