diff --git a/src/witness/circom.rs b/src/witness/circom.rs index f72563e..f9f5aed 100644 --- a/src/witness/circom.rs +++ b/src/witness/circom.rs @@ -7,9 +7,9 @@ pub struct Wasm(Instance); pub trait CircomBase { fn init(&self, sanity_check: bool) -> Result<()>; fn func(&self, name: &str) -> &Function; - fn get_ptr_witness_buffer(&self) -> Result; - fn get_ptr_witness(&self, w: i32) -> Result; - fn get_n_vars(&self) -> Result; + fn get_ptr_witness_buffer(&self) -> Result; + fn get_ptr_witness(&self, w: u32) -> Result; + fn get_n_vars(&self) -> Result; fn get_signal_offset32( &self, p_sig_offset: u32, @@ -17,41 +17,41 @@ pub trait CircomBase { hash_msb: u32, hash_lsb: u32, ) -> Result<()>; - fn set_signal(&self, c_idx: i32, component: i32, signal: i32, p_val: i32) -> Result<()>; - fn get_i32(&self, name: &str) -> Result; + fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()>; + fn get_u32(&self, name: &str) -> Result; // Only exists natively in Circom2, hardcoded for Circom - fn get_version(&self) -> Result; + fn get_version(&self) -> Result; } pub trait Circom { - fn get_fr_len(&self) -> Result; - fn get_ptr_raw_prime(&self) -> Result; + fn get_fr_len(&self) -> Result; + fn get_ptr_raw_prime(&self) -> Result; } pub trait Circom2 { - fn get_field_num_len32(&self) -> Result; + fn get_field_num_len32(&self) -> Result; fn get_raw_prime(&self) -> Result<()>; - fn read_shared_rw_memory(&self, i: i32) -> Result; - fn write_shared_rw_memory(&self, i: i32, v: i32) -> Result<()>; - fn set_input_signal(&self, hmsb: i32, hlsb: i32, pos: i32) -> Result<()>; - fn get_witness(&self, i: i32) -> Result<()>; - fn get_witness_size(&self) -> Result; + fn read_shared_rw_memory(&self, i: u32) -> Result; + fn write_shared_rw_memory(&self, i: u32, v: u32) -> Result<()>; + fn set_input_signal(&self, hmsb: u32, hlsb: u32, pos: u32) -> Result<()>; + fn get_witness(&self, i: u32) -> Result<()>; + fn get_witness_size(&self) -> Result; } impl Circom for Wasm { - fn get_fr_len(&self) -> Result { - self.get_i32("getFrLen") + fn get_fr_len(&self) -> Result { + self.get_u32("getFrLen") } - fn get_ptr_raw_prime(&self) -> Result { - self.get_i32("getPRawPrime") + fn get_ptr_raw_prime(&self) -> Result { + self.get_u32("getPRawPrime") } } #[cfg(feature = "circom-2")] impl Circom2 for Wasm { - fn get_field_num_len32(&self) -> Result { - self.get_i32("getFieldNumLen32") + fn get_field_num_len32(&self) -> Result { + self.get_u32("getFieldNumLen32") } fn get_raw_prime(&self) -> Result<()> { @@ -60,34 +60,32 @@ impl Circom2 for Wasm { Ok(()) } - fn read_shared_rw_memory(&self, i: i32) -> Result { + fn read_shared_rw_memory(&self, i: u32) -> Result { let func = self.func("readSharedRWMemory"); let result = func.call(&[i.into()])?; - Ok(result[0].unwrap_i32()) + Ok(result[0].unwrap_i32() as u32) } - fn write_shared_rw_memory(&self, i: i32, v: i32) -> Result<()> { + fn write_shared_rw_memory(&self, i: u32, v: u32) -> Result<()> { let func = self.func("writeSharedRWMemory"); func.call(&[i.into(), v.into()])?; Ok(()) } - fn set_input_signal(&self, hmsb: i32, hlsb: i32, pos: i32) -> Result<()> { + fn set_input_signal(&self, hmsb: u32, hlsb: u32, pos: u32) -> Result<()> { let func = self.func("setInputSignal"); func.call(&[hmsb.into(), hlsb.into(), pos.into()])?; Ok(()) } - fn get_witness(&self, i: i32) -> Result<()> { + fn get_witness(&self, i: u32) -> Result<()> { let func = self.func("getWitness"); func.call(&[i.into()])?; Ok(()) } - fn get_witness_size(&self) -> Result { - let func = self.func("getWitnessSize"); - let result = func.call(&[])?; - Ok(result[0].unwrap_i32()) + fn get_witness_size(&self) -> Result { + self.get_u32("getWitnessSize") } } @@ -98,19 +96,19 @@ impl CircomBase for Wasm { Ok(()) } - fn get_ptr_witness_buffer(&self) -> Result { - self.get_i32("getWitnessBuffer") + fn get_ptr_witness_buffer(&self) -> Result { + self.get_u32("getWitnessBuffer") } - fn get_ptr_witness(&self, w: i32) -> Result { + fn get_ptr_witness(&self, w: u32) -> Result { let func = self.func("getPWitness"); let res = func.call(&[w.into()])?; - Ok(res[0].unwrap_i32()) + Ok(res[0].unwrap_i32() as u32) } - fn get_n_vars(&self) -> Result { - self.get_i32("getNVars") + fn get_n_vars(&self) -> Result { + self.get_u32("getNVars") } fn get_signal_offset32( @@ -131,7 +129,7 @@ impl CircomBase for Wasm { Ok(()) } - fn set_signal(&self, c_idx: i32, component: i32, signal: i32, p_val: i32) -> Result<()> { + fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()> { let func = self.func("setSignal"); func.call(&[c_idx.into(), component.into(), signal.into(), p_val.into()])?; @@ -139,17 +137,17 @@ impl CircomBase for Wasm { } // Default to version 1 if it isn't explicitly defined - fn get_version(&self) -> Result { + fn get_version(&self) -> Result { match self.0.exports.get_function("getVersion") { - Ok(func) => Ok(func.call(&[])?[0].unwrap_i32()), + Ok(func) => Ok(func.call(&[])?[0].unwrap_i32() as u32), Err(_) => Ok(1), } } - fn get_i32(&self, name: &str) -> Result { + fn get_u32(&self, name: &str) -> Result { let func = self.func(name); let result = func.call(&[])?; - Ok(result[0].unwrap_i32()) + Ok(result[0].unwrap_i32() as u32) } fn func(&self, name: &str) -> &Function { diff --git a/src/witness/witness_calculator.rs b/src/witness/witness_calculator.rs index c5cad64..6af0961 100644 --- a/src/witness/witness_calculator.rs +++ b/src/witness/witness_calculator.rs @@ -17,8 +17,8 @@ use super::Circom; pub struct WitnessCalculator { pub instance: Wasm, pub memory: SafeMemory, - pub n64: i32, - pub circom_version: i32, + pub n64: u32, + pub circom_version: u32, } // Error type to signal end of execution. @@ -28,7 +28,7 @@ pub struct WitnessCalculator { struct ExitCode(u32); #[cfg(feature = "circom-2")] -fn from_array32(arr: Vec) -> BigInt { +fn from_array32(arr: Vec) -> BigInt { let mut res = BigInt::zero(); let radix = BigInt::from(0x100000000u64); for &val in arr.iter() { @@ -38,15 +38,15 @@ fn from_array32(arr: Vec) -> BigInt { } #[cfg(feature = "circom-2")] -fn to_array32(s: &BigInt, size: usize) -> Vec { +fn to_array32(s: &BigInt, size: usize) -> Vec { let mut res = vec![0; size as usize]; let mut rem = s.clone(); let radix = BigInt::from(0x100000000u64); - let mut c = size - 1; + let mut c = size; while !rem.is_zero() { - res[c] = (&rem % &radix).to_i32().unwrap(); - rem /= &radix; c -= 1; + res[c] = (&rem % &radix).to_u32().unwrap(); + rem /= &radix; } res @@ -77,16 +77,11 @@ impl WitnessCalculator { }; let instance = Wasm::new(Instance::new(&module, &import_object)?); - let version; - - match instance.get_version() { - Ok(v) => version = v, - Err(_) => version = 1, - } + let version = instance.get_version().unwrap_or(1); // Circom 2 feature flag with version 2 #[cfg(feature = "circom-2")] - fn new_circom2(instance: Wasm, memory: Memory, version: i32) -> Result { + fn new_circom2(instance: Wasm, memory: Memory, version: u32) -> Result { let n32 = instance.get_field_num_len32()?; let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero()); instance.get_raw_prime()?; @@ -97,7 +92,7 @@ impl WitnessCalculator { } let prime = from_array32(arr); - let n64 = ((prime.bits() - 1) / 64 + 1) as i32; + let n64 = ((prime.bits() - 1) / 64 + 1) as u32; safe_memory.prime = prime; Ok(WitnessCalculator { @@ -108,14 +103,14 @@ impl WitnessCalculator { }) } - fn new_circom1(instance: Wasm, memory: Memory, version: i32) -> Result { + fn new_circom1(instance: Wasm, memory: Memory, version: u32) -> Result { // Fallback to Circom 1 behavior let n32 = (instance.get_fr_len()? >> 2) - 2; let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero()); let ptr = instance.get_ptr_raw_prime()?; let prime = safe_memory.read_big(ptr as usize, n32 as usize)?; - let n64 = ((prime.bits() - 1) / 64 + 1) as i32; + let n64 = ((prime.bits() - 1) / 64 + 1) as u32; safe_memory.prime = prime; Ok(WitnessCalculator { @@ -190,7 +185,7 @@ impl WitnessCalculator { for (i, value) in values.into_iter().enumerate() { self.memory.write_fr(p_fr as usize, &value)?; self.instance - .set_signal(0, 0, (sig_offset + i) as i32, p_fr as i32)?; + .set_signal(0, 0, (sig_offset + i) as u32, p_fr as u32)?; } } @@ -227,12 +222,12 @@ impl WitnessCalculator { let f_arr = to_array32(&value, n32 as usize); for j in 0..n32 { self.instance.write_shared_rw_memory( - j as i32, + j as u32, f_arr[(n32 as usize) - 1 - (j as usize)], )?; } self.instance - .set_input_signal(msb as i32, lsb as i32, i as i32)?; + .set_input_signal(msb as u32, lsb as u32, i as u32)?; } } diff --git a/tests/groth16.rs b/tests/groth16.rs index f4873ef..112dfbc 100644 --- a/tests/groth16.rs +++ b/tests/groth16.rs @@ -90,3 +90,19 @@ fn groth16_proof_circom2() -> Result<()> { Ok(()) } + +#[test] +#[cfg(feature = "circom-2")] +fn witness_generation_circom2() -> Result<()> { + let cfg = CircomConfig::::new( + "./test-vectors/circom2_multiplier2.wasm", + "./test-vectors/circom2_multiplier2.r1cs", + )?; + let mut builder = CircomBuilder::new(cfg); + builder.push_input("a", 3); + builder.push_input("b", 0x100000000u64 - 1); + + assert!(builder.build().is_ok()); + + Ok(()) +}