mirror of
https://github.com/logos-storage/circom-compat.git
synced 2026-01-07 15:33:13 +00:00
Cleanly separate Circom1 and Circom2 traits (#60)
This commit is contained in:
parent
464d868ff4
commit
24353a4225
@ -7,9 +7,15 @@ pub struct Wasm(Instance);
|
|||||||
pub trait CircomBase {
|
pub trait CircomBase {
|
||||||
fn init(&self, sanity_check: bool) -> Result<()>;
|
fn init(&self, sanity_check: bool) -> Result<()>;
|
||||||
fn func(&self, name: &str) -> &Function;
|
fn func(&self, name: &str) -> &Function;
|
||||||
fn get_ptr_witness_buffer(&self) -> Result<u32>;
|
|
||||||
fn get_ptr_witness(&self, w: u32) -> Result<u32>;
|
|
||||||
fn get_n_vars(&self) -> Result<u32>;
|
fn get_n_vars(&self) -> Result<u32>;
|
||||||
|
fn get_u32(&self, name: &str) -> Result<u32>;
|
||||||
|
// Only exists natively in Circom2, hardcoded for Circom
|
||||||
|
fn get_version(&self) -> Result<u32>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Circom1 {
|
||||||
|
fn get_ptr_witness(&self, w: u32) -> Result<u32>;
|
||||||
|
fn get_fr_len(&self) -> Result<u32>;
|
||||||
fn get_signal_offset32(
|
fn get_signal_offset32(
|
||||||
&self,
|
&self,
|
||||||
p_sig_offset: u32,
|
p_sig_offset: u32,
|
||||||
@ -18,13 +24,6 @@ pub trait CircomBase {
|
|||||||
hash_lsb: u32,
|
hash_lsb: u32,
|
||||||
) -> Result<()>;
|
) -> Result<()>;
|
||||||
fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()>;
|
fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()>;
|
||||||
fn get_u32(&self, name: &str) -> Result<u32>;
|
|
||||||
// Only exists natively in Circom2, hardcoded for Circom
|
|
||||||
fn get_version(&self) -> Result<u32>;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait Circom {
|
|
||||||
fn get_fr_len(&self) -> Result<u32>;
|
|
||||||
fn get_ptr_raw_prime(&self) -> Result<u32>;
|
fn get_ptr_raw_prime(&self) -> Result<u32>;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -38,7 +37,7 @@ pub trait Circom2 {
|
|||||||
fn get_witness_size(&self) -> Result<u32>;
|
fn get_witness_size(&self) -> Result<u32>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Circom for Wasm {
|
impl Circom1 for Wasm {
|
||||||
fn get_fr_len(&self) -> Result<u32> {
|
fn get_fr_len(&self) -> Result<u32> {
|
||||||
self.get_u32("getFrLen")
|
self.get_u32("getFrLen")
|
||||||
}
|
}
|
||||||
@ -46,6 +45,38 @@ impl Circom for Wasm {
|
|||||||
fn get_ptr_raw_prime(&self) -> Result<u32> {
|
fn get_ptr_raw_prime(&self) -> Result<u32> {
|
||||||
self.get_u32("getPRawPrime")
|
self.get_u32("getPRawPrime")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_ptr_witness(&self, w: u32) -> Result<u32> {
|
||||||
|
let func = self.func("getPWitness");
|
||||||
|
let res = func.call(&[w.into()])?;
|
||||||
|
|
||||||
|
Ok(res[0].unwrap_i32() as u32)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_signal_offset32(
|
||||||
|
&self,
|
||||||
|
p_sig_offset: u32,
|
||||||
|
component: u32,
|
||||||
|
hash_msb: u32,
|
||||||
|
hash_lsb: u32,
|
||||||
|
) -> Result<()> {
|
||||||
|
let func = self.func("getSignalOffset32");
|
||||||
|
func.call(&[
|
||||||
|
p_sig_offset.into(),
|
||||||
|
component.into(),
|
||||||
|
hash_msb.into(),
|
||||||
|
hash_lsb.into(),
|
||||||
|
])?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()> {
|
||||||
|
let func = self.func("setSignal");
|
||||||
|
func.call(&[c_idx.into(), component.into(), signal.into(), p_val.into()])?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "circom-2")]
|
#[cfg(feature = "circom-2")]
|
||||||
@ -96,46 +127,10 @@ impl CircomBase for Wasm {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_ptr_witness_buffer(&self) -> Result<u32> {
|
|
||||||
self.get_u32("getWitnessBuffer")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_ptr_witness(&self, w: u32) -> Result<u32> {
|
|
||||||
let func = self.func("getPWitness");
|
|
||||||
let res = func.call(&[w.into()])?;
|
|
||||||
|
|
||||||
Ok(res[0].unwrap_i32() as u32)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_n_vars(&self) -> Result<u32> {
|
fn get_n_vars(&self) -> Result<u32> {
|
||||||
self.get_u32("getNVars")
|
self.get_u32("getNVars")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_signal_offset32(
|
|
||||||
&self,
|
|
||||||
p_sig_offset: u32,
|
|
||||||
component: u32,
|
|
||||||
hash_msb: u32,
|
|
||||||
hash_lsb: u32,
|
|
||||||
) -> Result<()> {
|
|
||||||
let func = self.func("getSignalOffset32");
|
|
||||||
func.call(&[
|
|
||||||
p_sig_offset.into(),
|
|
||||||
component.into(),
|
|
||||||
hash_msb.into(),
|
|
||||||
hash_lsb.into(),
|
|
||||||
])?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()> {
|
|
||||||
let func = self.func("setSignal");
|
|
||||||
func.call(&[c_idx.into(), component.into(), signal.into(), p_val.into()])?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default to version 1 if it isn't explicitly defined
|
// Default to version 1 if it isn't explicitly defined
|
||||||
fn get_version(&self) -> Result<u32> {
|
fn get_version(&self) -> Result<u32> {
|
||||||
match self.0.exports.get_function("getVersion") {
|
match self.0.exports.get_function("getVersion") {
|
||||||
|
|||||||
@ -10,7 +10,7 @@ pub(super) use circom::{CircomBase, Wasm};
|
|||||||
#[cfg(feature = "circom-2")]
|
#[cfg(feature = "circom-2")]
|
||||||
pub(super) use circom::Circom2;
|
pub(super) use circom::Circom2;
|
||||||
|
|
||||||
pub(super) use circom::Circom;
|
pub(super) use circom::Circom1;
|
||||||
|
|
||||||
use fnv::FnvHasher;
|
use fnv::FnvHasher;
|
||||||
use std::hash::Hasher;
|
use std::hash::Hasher;
|
||||||
|
|||||||
@ -2,23 +2,22 @@ use super::{fnv, CircomBase, SafeMemory, Wasm};
|
|||||||
use color_eyre::Result;
|
use color_eyre::Result;
|
||||||
use num_bigint::BigInt;
|
use num_bigint::BigInt;
|
||||||
use num_traits::Zero;
|
use num_traits::Zero;
|
||||||
use std::cell::Cell;
|
|
||||||
use wasmer::{imports, Function, Instance, Memory, MemoryType, Module, RuntimeError, Store};
|
use wasmer::{imports, Function, Instance, Memory, MemoryType, Module, RuntimeError, Store};
|
||||||
|
|
||||||
#[cfg(feature = "circom-2")]
|
#[cfg(feature = "circom-2")]
|
||||||
use num::ToPrimitive;
|
use num::ToPrimitive;
|
||||||
|
|
||||||
|
use super::Circom1;
|
||||||
#[cfg(feature = "circom-2")]
|
#[cfg(feature = "circom-2")]
|
||||||
use super::Circom2;
|
use super::Circom2;
|
||||||
|
|
||||||
use super::Circom;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct WitnessCalculator {
|
pub struct WitnessCalculator {
|
||||||
pub instance: Wasm,
|
pub instance: Wasm,
|
||||||
pub memory: SafeMemory,
|
pub memory: Option<SafeMemory>,
|
||||||
pub n64: u32,
|
pub n64: u32,
|
||||||
pub circom_version: u32,
|
pub circom_version: u32,
|
||||||
|
pub prime: BigInt,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error type to signal end of execution.
|
// Error type to signal end of execution.
|
||||||
@ -98,9 +97,8 @@ impl WitnessCalculator {
|
|||||||
|
|
||||||
// Circom 2 feature flag with version 2
|
// Circom 2 feature flag with version 2
|
||||||
#[cfg(feature = "circom-2")]
|
#[cfg(feature = "circom-2")]
|
||||||
fn new_circom2(instance: Wasm, memory: Memory, version: u32) -> Result<WitnessCalculator> {
|
fn new_circom2(instance: Wasm, version: u32) -> Result<WitnessCalculator> {
|
||||||
let n32 = instance.get_field_num_len32()?;
|
let n32 = instance.get_field_num_len32()?;
|
||||||
let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero());
|
|
||||||
instance.get_raw_prime()?;
|
instance.get_raw_prime()?;
|
||||||
let mut arr = vec![0; n32 as usize];
|
let mut arr = vec![0; n32 as usize];
|
||||||
for i in 0..n32 {
|
for i in 0..n32 {
|
||||||
@ -110,13 +108,13 @@ impl WitnessCalculator {
|
|||||||
let prime = from_array32(arr);
|
let prime = from_array32(arr);
|
||||||
|
|
||||||
let n64 = ((prime.bits() - 1) / 64 + 1) as u32;
|
let n64 = ((prime.bits() - 1) / 64 + 1) as u32;
|
||||||
safe_memory.prime = prime;
|
|
||||||
|
|
||||||
Ok(WitnessCalculator {
|
Ok(WitnessCalculator {
|
||||||
instance,
|
instance,
|
||||||
memory: safe_memory,
|
memory: None,
|
||||||
n64,
|
n64,
|
||||||
circom_version: version,
|
circom_version: version,
|
||||||
|
prime,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -128,13 +126,14 @@ impl WitnessCalculator {
|
|||||||
let prime = safe_memory.read_big(ptr as usize, n32 as usize)?;
|
let prime = safe_memory.read_big(ptr as usize, n32 as usize)?;
|
||||||
|
|
||||||
let n64 = ((prime.bits() - 1) / 64 + 1) as u32;
|
let n64 = ((prime.bits() - 1) / 64 + 1) as u32;
|
||||||
safe_memory.prime = prime;
|
safe_memory.prime = prime.clone();
|
||||||
|
|
||||||
Ok(WitnessCalculator {
|
Ok(WitnessCalculator {
|
||||||
instance,
|
instance,
|
||||||
memory: safe_memory,
|
memory: Some(safe_memory),
|
||||||
n64,
|
n64,
|
||||||
circom_version: version,
|
circom_version: version,
|
||||||
|
prime,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -148,7 +147,7 @@ impl WitnessCalculator {
|
|||||||
cfg_if::cfg_if! {
|
cfg_if::cfg_if! {
|
||||||
if #[cfg(feature = "circom-2")] {
|
if #[cfg(feature = "circom-2")] {
|
||||||
match version {
|
match version {
|
||||||
2 => new_circom2(instance, memory, version),
|
2 => new_circom2(instance, version),
|
||||||
1 => new_circom1(instance, memory, version),
|
1 => new_circom1(instance, memory, version),
|
||||||
_ => panic!("Unknown Circom version")
|
_ => panic!("Unknown Circom version")
|
||||||
}
|
}
|
||||||
@ -186,9 +185,9 @@ impl WitnessCalculator {
|
|||||||
) -> Result<Vec<BigInt>> {
|
) -> Result<Vec<BigInt>> {
|
||||||
self.instance.init(sanity_check)?;
|
self.instance.init(sanity_check)?;
|
||||||
|
|
||||||
let old_mem_free_pos = self.memory.free_pos();
|
let old_mem_free_pos = self.memory.as_ref().unwrap().free_pos();
|
||||||
let p_sig_offset = self.memory.alloc_u32();
|
let p_sig_offset = self.memory.as_mut().unwrap().alloc_u32();
|
||||||
let p_fr = self.memory.alloc_fr();
|
let p_fr = self.memory.as_mut().unwrap().alloc_fr();
|
||||||
|
|
||||||
// allocate the inputs
|
// allocate the inputs
|
||||||
for (name, values) in inputs.into_iter() {
|
for (name, values) in inputs.into_iter() {
|
||||||
@ -197,10 +196,17 @@ impl WitnessCalculator {
|
|||||||
self.instance
|
self.instance
|
||||||
.get_signal_offset32(p_sig_offset, 0, msb, lsb)?;
|
.get_signal_offset32(p_sig_offset, 0, msb, lsb)?;
|
||||||
|
|
||||||
let sig_offset = self.memory.read_u32(p_sig_offset as usize) as usize;
|
let sig_offset = self
|
||||||
|
.memory
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.read_u32(p_sig_offset as usize) as usize;
|
||||||
|
|
||||||
for (i, value) in values.into_iter().enumerate() {
|
for (i, value) in values.into_iter().enumerate() {
|
||||||
self.memory.write_fr(p_fr as usize, &value)?;
|
self.memory
|
||||||
|
.as_mut()
|
||||||
|
.unwrap()
|
||||||
|
.write_fr(p_fr as usize, &value)?;
|
||||||
self.instance
|
self.instance
|
||||||
.set_signal(0, 0, (sig_offset + i) as u32, p_fr)?;
|
.set_signal(0, 0, (sig_offset + i) as u32, p_fr)?;
|
||||||
}
|
}
|
||||||
@ -211,11 +217,11 @@ impl WitnessCalculator {
|
|||||||
let n_vars = self.instance.get_n_vars()?;
|
let n_vars = self.instance.get_n_vars()?;
|
||||||
for i in 0..n_vars {
|
for i in 0..n_vars {
|
||||||
let ptr = self.instance.get_ptr_witness(i)? as usize;
|
let ptr = self.instance.get_ptr_witness(i)? as usize;
|
||||||
let el = self.memory.read_fr(ptr)?;
|
let el = self.memory.as_ref().unwrap().read_fr(ptr)?;
|
||||||
w.push(el);
|
w.push(el);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.memory.set_free_pos(old_mem_free_pos);
|
self.memory.as_mut().unwrap().set_free_pos(old_mem_free_pos);
|
||||||
|
|
||||||
Ok(w)
|
Ok(w)
|
||||||
}
|
}
|
||||||
@ -289,20 +295,6 @@ impl WitnessCalculator {
|
|||||||
|
|
||||||
Ok(witness)
|
Ok(witness)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_witness_buffer(&self) -> Result<Vec<u8>> {
|
|
||||||
let ptr = self.instance.get_ptr_witness_buffer()? as usize;
|
|
||||||
|
|
||||||
let view = self.memory.memory.view::<u8>();
|
|
||||||
|
|
||||||
let len = self.instance.get_n_vars()? * self.n64 * 8;
|
|
||||||
let arr = view[ptr..ptr + len as usize]
|
|
||||||
.iter()
|
|
||||||
.map(Cell::get)
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
Ok(arr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// callback hooks for debugging
|
// callback hooks for debugging
|
||||||
@ -469,7 +461,7 @@ mod tests {
|
|||||||
fn run_test(case: TestCase) {
|
fn run_test(case: TestCase) {
|
||||||
let mut wtns = WitnessCalculator::new(case.circuit_path).unwrap();
|
let mut wtns = WitnessCalculator::new(case.circuit_path).unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
wtns.memory.prime.to_str_radix(16),
|
wtns.prime.to_str_radix(16),
|
||||||
"30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001".to_lowercase()
|
"30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001".to_lowercase()
|
||||||
);
|
);
|
||||||
assert_eq!({ wtns.instance.get_n_vars().unwrap() }, case.n_vars);
|
assert_eq!({ wtns.instance.get_n_vars().unwrap() }, case.n_vars);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user