From 29a7555c084e13f2c59801be953c2e1f26a694c9 Mon Sep 17 00:00:00 2001 From: Georgios Konstantopoulos Date: Mon, 26 Jul 2021 17:38:29 +0300 Subject: [PATCH] zkey: expose only 1 method - everything else private --- src/zkey.rs | 159 +++++++++++++++++++--------------------------------- 1 file changed, 59 insertions(+), 100 deletions(-) diff --git a/src/zkey.rs b/src/zkey.rs index d1a9679..1c8531e 100644 --- a/src/zkey.rs +++ b/src/zkey.rs @@ -1,4 +1,4 @@ -//! ZKey +//! ZKey Parsing //! //! Each ZKey file is broken into sections: //! Header(1) @@ -32,30 +32,35 @@ use byteorder::{LittleEndian, ReadBytesExt}; use std::{ collections::HashMap, - io::{Cursor, Read, Result as IoResult}, + io::{Read, Result as IoResult, Seek, SeekFrom}, }; use ark_bn254::{Bn254, Fq, Fq2, G1Affine, G2Affine}; use ark_groth16::{ProvingKey, VerifyingKey}; -use ark_serialize::CanonicalSerialize; use num_traits::Zero; #[derive(Clone, Debug)] -pub struct Section { +struct Section { position: u64, size: usize, } +/// Reads a SnarkJS ZKey file into an Arkworks ProvingKey. +pub fn read_zkey(reader: &mut R) -> IoResult> { + let mut binfile = BinFile::new(reader)?; + binfile.proving_key() +} + #[derive(Debug)] -pub struct BinFile<'a> { +struct BinFile<'a, R> { ftype: String, version: u32, sections: HashMap>, - reader: &'a mut Cursor<&'a [u8]>, + reader: &'a mut R, } -impl<'a> BinFile<'a> { - pub fn new(reader: &'a mut Cursor<&'a [u8]>) -> IoResult { +impl<'a, R: Read + Seek> BinFile<'a, R> { + fn new(reader: &'a mut R) -> IoResult { let mut magic = [0u8; 4]; reader.read_exact(&mut magic)?; @@ -70,11 +75,11 @@ impl<'a> BinFile<'a> { let section = sections.entry(section_id).or_insert_with(Vec::new); section.push(Section { - position: reader.position(), + position: reader.stream_position()?, size: section_length as usize, }); - reader.set_position(reader.position() + section_length); + reader.seek(SeekFrom::Current(section_length as i64))?; } Ok(Self { @@ -85,7 +90,7 @@ impl<'a> BinFile<'a> { }) } - pub fn proving_key(&mut self) -> IoResult> { + fn proving_key(&mut self) -> IoResult> { let header = self.groth_header()?; let ic = self.ic(header.n_public)?; @@ -121,64 +126,60 @@ impl<'a> BinFile<'a> { self.sections.get(&id).unwrap()[0].clone() } - pub fn groth_header(&mut self) -> IoResult { + fn groth_header(&mut self) -> IoResult { let section = self.get_section(2); let header = HeaderGroth::new(&mut self.reader, §ion)?; Ok(header) } - pub fn ic(&mut self, n_public: usize) -> IoResult> { + fn ic(&mut self, n_public: usize) -> IoResult> { // the range is non-inclusive so we do +1 to get all inputs self.g1_section(n_public + 1, 3) } // Section 4 is the coefficients, we ignore it - pub fn a_query(&mut self, n_vars: usize) -> IoResult> { + fn a_query(&mut self, n_vars: usize) -> IoResult> { self.g1_section(n_vars, 5) } - pub fn b_g1_query(&mut self, n_vars: usize) -> IoResult> { + fn b_g1_query(&mut self, n_vars: usize) -> IoResult> { self.g1_section(n_vars, 6) } - pub fn b_g2_query(&mut self, n_vars: usize) -> IoResult> { + fn b_g2_query(&mut self, n_vars: usize) -> IoResult> { self.g2_section(n_vars, 7) } - pub fn l_query(&mut self, n_vars: usize) -> IoResult> { + fn l_query(&mut self, n_vars: usize) -> IoResult> { self.g1_section(n_vars, 8) } - pub fn h_query(&mut self, n_vars: usize) -> IoResult> { + fn h_query(&mut self, n_vars: usize) -> IoResult> { self.g1_section(n_vars, 9) } fn g1_section(&mut self, num: usize, section_id: usize) -> IoResult> { let section = self.get_section(section_id as u32); - deserialize_g1_vec( - &self.reader.get_ref()[section.position as usize..], - num as u32, - ) + self.reader.seek(SeekFrom::Start(section.position))?; + deserialize_g1_vec(self.reader, num as u32) } fn g2_section(&mut self, num: usize, section_id: usize) -> IoResult> { let section = self.get_section(section_id as u32); - deserialize_g2_vec( - &self.reader.get_ref()[section.position as usize..], - num as u32, - ) + self.reader.seek(SeekFrom::Start(section.position))?; + deserialize_g2_vec(self.reader, num as u32) } } #[derive(Default, Clone, Debug, CanonicalDeserialize)] pub struct ZVerifyingKey { - pub alpha_g1: G1Affine, - pub beta_g1: G1Affine, - pub beta_g2: G2Affine, - pub gamma_g2: G2Affine, - pub delta_g1: G1Affine, - pub delta_g2: G2Affine, + alpha_g1: G1Affine, + beta_g1: G1Affine, + beta_g2: G2Affine, + gamma_g2: G2Affine, + delta_g1: G1Affine, + delta_g2: G2Affine, } impl ZVerifyingKey { @@ -202,25 +203,25 @@ impl ZVerifyingKey { } #[derive(Clone, Debug)] -pub struct HeaderGroth { - pub n8q: u32, - pub q: BigInteger256, +struct HeaderGroth { + n8q: u32, + q: BigInteger256, - pub n8r: u32, - pub r: BigInteger256, + n8r: u32, + r: BigInteger256, - pub n_vars: usize, - pub n_public: usize, + n_vars: usize, + n_public: usize, - pub domain_size: u32, - pub power: u32, + domain_size: u32, + power: u32, - pub verifying_key: ZVerifyingKey, + verifying_key: ZVerifyingKey, } impl HeaderGroth { - pub fn new(reader: &mut Cursor<&[u8]>, section: &Section) -> IoResult { - reader.set_position(section.position); + fn new(reader: &mut R, section: &Section) -> IoResult { + reader.seek(SeekFrom::Start(section.position))?; Self::read(reader) } @@ -282,31 +283,18 @@ fn deserialize_g2(reader: &mut R) -> IoResult { Ok(G2Affine::new(f1, f2, infinity)) } -fn deserialize_g1_vec(buf: &[u8], n_vars: u32) -> IoResult> { - let size = G1Affine::zero().uncompressed_size(); - let mut v = vec![]; - for i in 0..n_vars as usize { - let el = deserialize_g1(&mut &buf[i * size..(i + 1) * size])?; - v.push(el); - } - Ok(v) +fn deserialize_g1_vec(reader: &mut R, n_vars: u32) -> IoResult> { + (0..n_vars).map(|_| deserialize_g1(reader)).collect() } -fn deserialize_g2_vec(buf: &[u8], n_vars: u32) -> IoResult> { - let size = G2Affine::zero().uncompressed_size(); - let mut v = vec![]; - for i in 0..n_vars as usize { - let el = deserialize_g2(&mut &buf[i * size..(i + 1) * size])?; - v.push(el); - } - Ok(v) +fn deserialize_g2_vec(reader: &mut R, n_vars: u32) -> IoResult> { + (0..n_vars).map(|_| deserialize_g2(reader)).collect() } #[cfg(test)] mod tests { use super::*; use ark_bn254::{G1Projective, G2Projective}; - use memmap::*; use num_bigint::BigUint; use serde_json::Value; use std::fs::File; @@ -420,7 +408,7 @@ mod tests { .collect::>(); let expected = vec![g1_one(); n_vars]; - let de = deserialize_g1_vec(&buf[..], n_vars as u32).unwrap(); + let de = deserialize_g1_vec(&mut &buf[..], n_vars as u32).unwrap(); assert_eq!(expected, de); } @@ -444,7 +432,7 @@ mod tests { .collect::>(); let expected = vec![g2_one(); n_vars]; - let de = deserialize_g2_vec(&buf[..], n_vars as u32).unwrap(); + let de = deserialize_g2_vec(&mut &buf[..], n_vars as u32).unwrap(); assert_eq!(expected, de); } @@ -465,14 +453,8 @@ mod tests { // Then: // `snarkjs zkey new circuit.r1cs powersOfTau28_hez_final_10.ptau test.zkey` let path = "./test-vectors/test.zkey"; - let file = File::open(path).unwrap(); - let map = unsafe { - MmapOptions::new() - .map(&file) - .expect("unable to create a memory map") - }; - let mut reader = Cursor::new(map.as_ref()); - let mut binfile = BinFile::new(&mut reader).unwrap(); + let mut file = File::open(path).unwrap(); + let mut binfile = BinFile::new(&mut file).unwrap(); let header = binfile.groth_header().unwrap(); assert_eq!(header.n_vars, 4); assert_eq!(header.n_public, 1); @@ -483,15 +465,8 @@ mod tests { #[test] fn deser_key() { let path = "./test-vectors/test.zkey"; - let file = File::open(path).unwrap(); - let map = unsafe { - MmapOptions::new() - .map(&file) - .expect("unable to create a memory map") - }; - let mut reader = Cursor::new(map.as_ref()); - let mut binfile = BinFile::new(&mut reader).unwrap(); - let params = binfile.proving_key().unwrap(); + let mut file = File::open(path).unwrap(); + let params = read_zkey(&mut file).unwrap(); // Check IC let expected = vec![ @@ -710,16 +685,8 @@ mod tests { #[test] fn deser_vk() { let path = "./test-vectors/test.zkey"; - let file = File::open(path).unwrap(); - let map = unsafe { - MmapOptions::new() - .map(&file) - .expect("unable to create a memory map") - }; - let mut reader = Cursor::new(map.as_ref()); - let mut binfile = BinFile::new(&mut reader).unwrap(); - - let params = binfile.proving_key().unwrap(); + let mut file = File::open(path).unwrap(); + let params = read_zkey(&mut file).unwrap(); let json = std::fs::read_to_string("./test-vectors/verification_key.json").unwrap(); let json: Value = serde_json::from_str(&json).unwrap(); @@ -799,16 +766,8 @@ mod tests { #[test] fn verify_proof_with_zkey() { let path = "./test-vectors/test.zkey"; - let file = File::open(path).unwrap(); - let map = unsafe { - MmapOptions::new() - .map(&file) - .expect("unable to create a memory map") - }; - let mut reader = Cursor::new(map.as_ref()); - let mut binfile = BinFile::new(&mut reader).unwrap(); - - let params = binfile.proving_key().unwrap(); + let mut file = File::open(path).unwrap(); + let params = read_zkey(&mut file).unwrap(); // binfile.proving_key().unwrap(); let cfg = CircuitConfig::::new( "./test-vectors/mycircuit.wasm",