Update Compute*KZGProof in rust bindings (#183)

* Update Compute*KZGProof in rust bindings

* Remove the boxing from the blobs

and implement get_blobs() a bit less promiscuously

* Improve pattern matching style

* Run `cargo fmt`

* Remove a println

* No need to clone commitments
This commit is contained in:
George Kadianakis 2023-03-09 13:00:17 +02:00 committed by GitHub
parent c295688099
commit 599ae2fe21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 85 additions and 44 deletions

View File

@ -220,7 +220,8 @@ extern "C" {
) -> C_KZG_RET;
pub fn compute_kzg_proof(
out: *mut KZGProof,
proof_out: *mut KZGProof,
y_out: *mut Bytes32,
blob: *const Blob,
z_bytes: *const Bytes32,
s: *const KZGSettings,
@ -229,6 +230,7 @@ extern "C" {
pub fn compute_blob_kzg_proof(
out: *mut KZGProof,
blob: *const Blob,
commitment_bytes: *const Bytes48,
s: *const KZGSettings,
) -> C_KZG_RET;

View File

@ -110,7 +110,7 @@ impl Drop for KZGSettings {
}
impl Blob {
pub fn from_bytes(bytes: &[u8]) -> Result<Box<Self>, Error> {
pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
if bytes.len() != BYTES_PER_BLOB {
return Err(Error::InvalidBytesLength(format!(
"Invalid byte length. Expected {} got {}",
@ -120,7 +120,7 @@ impl Blob {
}
let mut new_bytes = [0; BYTES_PER_BLOB];
new_bytes.copy_from_slice(bytes);
Ok(Box::new(Self { bytes: new_bytes }))
Ok(Self { bytes: new_bytes })
}
}
@ -182,22 +182,38 @@ impl KZGProof {
blob: Blob,
z_bytes: Bytes32,
kzg_settings: &KZGSettings,
) -> Result<Self, Error> {
) -> Result<(Self, Bytes32), Error> {
let mut kzg_proof = MaybeUninit::<KZGProof>::uninit();
let mut y_out = MaybeUninit::<Bytes32>::uninit();
unsafe {
let res = compute_kzg_proof(kzg_proof.as_mut_ptr(), &blob, &z_bytes, kzg_settings);
let res = compute_kzg_proof(
kzg_proof.as_mut_ptr(),
y_out.as_mut_ptr(),
&blob,
&z_bytes,
kzg_settings,
);
if let C_KZG_RET::C_KZG_OK = res {
Ok(kzg_proof.assume_init())
Ok((kzg_proof.assume_init(), y_out.assume_init()))
} else {
Err(Error::CError(res))
}
}
}
pub fn compute_blob_kzg_proof(blob: Blob, kzg_settings: &KZGSettings) -> Result<Self, Error> {
pub fn compute_blob_kzg_proof(
blob: Blob,
commitment_bytes: Bytes48,
kzg_settings: &KZGSettings,
) -> Result<Self, Error> {
let mut kzg_proof = MaybeUninit::<KZGProof>::uninit();
unsafe {
let res = compute_blob_kzg_proof(kzg_proof.as_mut_ptr(), &blob, kzg_settings);
let res = compute_blob_kzg_proof(
kzg_proof.as_mut_ptr(),
&blob,
&commitment_bytes,
kzg_settings,
);
if let C_KZG_RET::C_KZG_OK = res {
Ok(kzg_proof.assume_init())
} else {
@ -397,16 +413,17 @@ mod tests {
.collect();
let commitments: Vec<Bytes48> = blobs
.clone()
.into_iter()
.map(|blob| KZGCommitment::blob_to_kzg_commitment(blob, &kzg_settings).unwrap())
.iter()
.map(|blob| KZGCommitment::blob_to_kzg_commitment(*blob, &kzg_settings).unwrap())
.map(|commitment| commitment.to_bytes())
.collect();
let proofs: Vec<Bytes48> = blobs
.clone()
.into_iter()
.map(|blob| KZGProof::compute_blob_kzg_proof(blob, &kzg_settings).unwrap())
.iter()
.zip(commitments.iter())
.map(|(blob, commitment)| {
KZGProof::compute_blob_kzg_proof(*blob, *commitment, &kzg_settings).unwrap()
})
.map(|proof| proof.to_bytes())
.collect();
@ -469,7 +486,7 @@ mod tests {
continue;
};
match KZGCommitment::blob_to_kzg_commitment(*blob, &kzg_settings) {
match KZGCommitment::blob_to_kzg_commitment(blob, &kzg_settings) {
Ok(res) => assert_eq!(res.bytes, test.get_output().unwrap().bytes),
_ => assert!(test.get_output().is_none()),
}
@ -491,8 +508,11 @@ mod tests {
continue;
};
match KZGProof::compute_kzg_proof(*blob, z, &kzg_settings) {
Ok(res) => assert_eq!(res.bytes, test.get_output().unwrap().bytes),
match KZGProof::compute_kzg_proof(blob, z, &kzg_settings) {
Ok((proof, y)) => {
assert_eq!(proof.bytes, test.get_output().unwrap().0.bytes);
assert_eq!(y.bytes, test.get_output().unwrap().1.bytes);
}
_ => assert!(test.get_output().is_none()),
}
}
@ -508,12 +528,15 @@ mod tests {
for test_file in glob::glob(COMPUTE_BLOB_KZG_PROOF_TESTS).unwrap() {
let yaml_data = fs::read_to_string(test_file.unwrap()).unwrap();
let test: compute_blob_kzg_proof::Test = serde_yaml::from_str(&yaml_data).unwrap();
let Ok(blob) = test.input.get_blob() else {
let (Ok(blob), Ok(commitment)) = (
test.input.get_blob(),
test.input.get_commitment()
) else {
assert!(test.get_output().is_none());
continue;
};
match KZGProof::compute_blob_kzg_proof(*blob, &kzg_settings) {
match KZGProof::compute_blob_kzg_proof(blob, commitment, &kzg_settings) {
Ok(res) => assert_eq!(res.bytes, test.get_output().unwrap().bytes),
_ => assert!(test.get_output().is_none()),
}
@ -566,7 +589,7 @@ mod tests {
continue;
};
match KZGProof::verify_blob_kzg_proof(*blob, commitment, proof, &kzg_settings) {
match KZGProof::verify_blob_kzg_proof(blob, commitment, proof, &kzg_settings) {
Ok(res) => assert_eq!(res, test.get_output().unwrap()),
_ => assert!(test.get_output().is_none()),
}
@ -593,13 +616,9 @@ mod tests {
};
match KZGProof::verify_blob_kzg_proof_batch(
blobs
.into_iter()
.map(|b| *b)
.collect::<Vec<Blob>>()
.as_slice(),
commitments.as_slice(),
proofs.as_slice(),
&blobs,
&commitments,
&proofs,
&kzg_settings,
) {
Ok(res) => assert_eq!(res, test.get_output().unwrap()),

View File

@ -9,7 +9,7 @@ pub struct Input<'a> {
}
impl Input<'_> {
pub fn get_blob(&self) -> Result<Box<Blob>, Error> {
pub fn get_blob(&self) -> Result<Blob, Error> {
let hex_str = self.blob.replace("0x", "");
let bytes = hex::decode(hex_str).unwrap();
Blob::from_bytes(&bytes)

View File

@ -6,14 +6,21 @@ use serde::Deserialize;
#[derive(Deserialize)]
pub struct Input<'a> {
blob: &'a str,
commitment: &'a str,
}
impl Input<'_> {
pub fn get_blob(&self) -> Result<Box<Blob>, Error> {
pub fn get_blob(&self) -> Result<Blob, Error> {
let hex_str = self.blob.replace("0x", "");
let bytes = hex::decode(hex_str).unwrap();
Blob::from_bytes(&bytes)
}
pub fn get_commitment(&self) -> Result<Bytes48, Error> {
let hex_str = self.commitment.replace("0x", "");
let bytes = hex::decode(hex_str).unwrap();
Bytes48::from_bytes(&bytes)
}
}
#[derive(Deserialize)]

View File

@ -10,7 +10,7 @@ pub struct Input<'a> {
}
impl Input<'_> {
pub fn get_blob(&self) -> Result<Box<Blob>, Error> {
pub fn get_blob(&self) -> Result<Blob, Error> {
let hex_str = self.blob.replace("0x", "");
let bytes = hex::decode(hex_str).unwrap();
Blob::from_bytes(&bytes)
@ -28,14 +28,23 @@ pub struct Test<'a> {
#[serde(borrow)]
pub input: Input<'a>,
#[serde(borrow)]
output: Option<&'a str>,
output: Option<(&'a str, &'a str)>,
}
impl Test<'_> {
pub fn get_output(&self) -> Option<Bytes48> {
self.output
.map(|s| s.replace("0x", ""))
.map(|hex_str| hex::decode(hex_str).unwrap())
.map(|bytes| Bytes48::from_bytes(&bytes).unwrap())
pub fn get_output(&self) -> Option<(Bytes48, Bytes32)> {
if self.output.is_none() {
return None;
}
let proof_hex = self.output.as_ref().unwrap().0.replace("0x", "");
let proof_bytes = hex::decode(proof_hex).unwrap();
let proof = Bytes48::from_bytes(&proof_bytes).unwrap();
let z_hex = self.output.as_ref().unwrap().1.replace("0x", "");
let z_bytes = hex::decode(z_hex).unwrap();
let z = Bytes32::from_bytes(&z_bytes).unwrap();
Some((proof, z))
}
}

View File

@ -11,7 +11,7 @@ pub struct Input<'a> {
}
impl Input<'_> {
pub fn get_blob(&self) -> Result<Box<Blob>, Error> {
pub fn get_blob(&self) -> Result<Blob, Error> {
let hex_str = self.blob.replace("0x", "");
let bytes = hex::decode(hex_str).unwrap();
Blob::from_bytes(&bytes)

View File

@ -11,13 +11,17 @@ pub struct Input {
}
impl Input {
pub fn get_blobs(&self) -> Result<Vec<Box<Blob>>, Error> {
self.blobs
.iter()
.map(|s| s.replace("0x", ""))
.map(|hex_str| hex::decode(hex_str).unwrap())
.map(|bytes| Blob::from_bytes(bytes.as_slice()))
.collect::<Result<Vec<Box<Blob>>, Error>>()
pub fn get_blobs(&self) -> Result<Vec<Blob>, Error> {
let mut v: Vec<Blob> = Vec::new();
for blob in &self.blobs {
let blob_hex = blob.replace("0x", "");
let blob_bytes = hex::decode(blob_hex).unwrap();
let b = Blob::from_bytes(blob_bytes.as_slice())?;
v.push(b);
}
return Ok(v);
}
pub fn get_commitments(&self) -> Result<Vec<Bytes48>, Error> {