Add some tests for reverse_bits (#105)

This commit is contained in:
Justin Traglia 2023-01-31 16:51:35 +01:00 committed by GitHub
parent d3b061f84b
commit eb17071bf3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 11 deletions

View File

@ -145,10 +145,10 @@ static const uint64_t SCALE2_ROOT_OF_UNITY[][4] = {
}; };
/** The zero field element. */ /** The zero field element. */
static const fr_t fr_zero = {0L, 0L, 0L, 0L}; static const fr_t FR_ZERO = {0L, 0L, 0L, 0L};
/** This is 1 in Blst's `blst_fr` limb representation. Crazy but true. */ /** This is 1 in Blst's `blst_fr` limb representation. Crazy but true. */
static const fr_t fr_one = {0x00000001fffffffeL, 0x5884b7fa00034802L, 0x998c4fefecbc4ff5L, 0x1824b159acc5056fL}; static const fr_t FR_ONE = {0x00000001fffffffeL, 0x5884b7fa00034802L, 0x998c4fefecbc4ff5L, 0x1824b159acc5056fL};
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Memory Allocation Functions // Memory Allocation Functions
@ -226,7 +226,7 @@ static C_KZG_RET new_fr_array(fr_t **x, size_t n) {
* @param[in] b A non-zero byte * @param[in] b A non-zero byte
* @return The index of the highest set bit * @return The index of the highest set bit
*/ */
static int log_2_byte(byte b) { STATIC int log_2_byte(byte b) {
int r, shift; int r, shift;
r = (b > 0xF) << 2; r = (b > 0xF) << 2;
b >>= r; b >>= r;
@ -292,7 +292,7 @@ static void fr_div(fr_t *out, const fr_t *a, const fr_t *b) {
*/ */
static void fr_pow(fr_t *out, const fr_t *a, uint64_t n) { static void fr_pow(fr_t *out, const fr_t *a, uint64_t n) {
fr_t tmp = *a; fr_t tmp = *a;
*out = fr_one; *out = FR_ONE;
while (true) { while (true) {
if (n & 1) { if (n & 1) {
@ -556,7 +556,7 @@ static bool is_power_of_two(uint64_t n) {
* @param[in] a The integer to be reversed * @param[in] a The integer to be reversed
* @return An integer with the bits of @p a reversed * @return An integer with the bits of @p a reversed
*/ */
static uint32_t reverse_bits(uint32_t a) { STATIC uint32_t reverse_bits(uint32_t a) {
return rev_4byte(a); return rev_4byte(a);
} }
@ -862,7 +862,7 @@ static void poly_lincomb(Polynomial *out, const Polynomial *vectors, const fr_t
fr_t tmp; fr_t tmp;
uint64_t i, j; uint64_t i, j;
for (j = 0; j < FIELD_ELEMENTS_PER_BLOB; j++) for (j = 0; j < FIELD_ELEMENTS_PER_BLOB; j++)
out->evals[j] = fr_zero; out->evals[j] = FR_ZERO;
for (i = 0; i < n; i++) { for (i = 0; i < n; i++) {
for (j = 0; j < FIELD_ELEMENTS_PER_BLOB; j++) { for (j = 0; j < FIELD_ELEMENTS_PER_BLOB; j++) {
blst_fr_mul(&tmp, &scalars[i], &vectors[i].evals[j]); blst_fr_mul(&tmp, &scalars[i], &vectors[i].evals[j]);
@ -881,7 +881,7 @@ static void poly_lincomb(Polynomial *out, const Polynomial *vectors, const fr_t
* @param[in] n The number of powers to compute * @param[in] n The number of powers to compute
*/ */
static void compute_powers(fr_t *out, fr_t *x, uint64_t n) { static void compute_powers(fr_t *out, fr_t *x, uint64_t n) {
fr_t current_power = fr_one; fr_t current_power = FR_ONE;
for (uint64_t i = 0; i < n; i++) { for (uint64_t i = 0; i < n; i++) {
out[i] = current_power; out[i] = current_power;
blst_fr_mul(&current_power, &current_power, x); blst_fr_mul(&current_power, &current_power, x);
@ -927,7 +927,7 @@ STATIC C_KZG_RET evaluate_polynomial_in_evaluation_form(fr_t *out, const Polynom
ret = fr_batch_inv(inverses, inverses_in, FIELD_ELEMENTS_PER_BLOB); ret = fr_batch_inv(inverses, inverses_in, FIELD_ELEMENTS_PER_BLOB);
if (ret != C_KZG_OK) goto out; if (ret != C_KZG_OK) goto out;
*out = fr_zero; *out = FR_ZERO;
for (i = 0; i < FIELD_ELEMENTS_PER_BLOB; i++) { for (i = 0; i < FIELD_ELEMENTS_PER_BLOB; i++) {
blst_fr_mul(&tmp, &inverses[i], &roots_of_unity[i]); blst_fr_mul(&tmp, &inverses[i], &roots_of_unity[i]);
blst_fr_mul(&tmp, &tmp, &p->evals[i]); blst_fr_mul(&tmp, &tmp, &p->evals[i]);
@ -936,7 +936,7 @@ STATIC C_KZG_RET evaluate_polynomial_in_evaluation_form(fr_t *out, const Polynom
fr_from_uint64(&tmp, FIELD_ELEMENTS_PER_BLOB); fr_from_uint64(&tmp, FIELD_ELEMENTS_PER_BLOB);
fr_div(out, out, &tmp); fr_div(out, out, &tmp);
fr_pow(&tmp, x, FIELD_ELEMENTS_PER_BLOB); fr_pow(&tmp, x, FIELD_ELEMENTS_PER_BLOB);
blst_fr_sub(&tmp, &tmp, &fr_one); blst_fr_sub(&tmp, &tmp, &FR_ONE);
blst_fr_mul(out, out, &tmp); blst_fr_mul(out, out, &tmp);
out: out:
@ -1125,7 +1125,7 @@ C_KZG_RET compute_kzg_proof_impl(KZGProof *out, const Polynomial *polynomial, co
} }
if (m) { // ω_m == z if (m) { // ω_m == z
q.evals[--m] = fr_zero; q.evals[--m] = FR_ZERO;
for (i = 0; i < FIELD_ELEMENTS_PER_BLOB; i++) { for (i = 0; i < FIELD_ELEMENTS_PER_BLOB; i++) {
if (i == m) continue; if (i == m) continue;
// (p_i - y) * ω_i / (z * (z - ω_i)) // (p_i - y) * ω_i / (z * (z - ω_i))
@ -1389,7 +1389,7 @@ static C_KZG_RET fft_g1(g1_t *out, const g1_t *in, bool inverse, uint64_t n, con
* @retval C_CZK_BADARGS Invalid parameters were supplied * @retval C_CZK_BADARGS Invalid parameters were supplied
*/ */
static C_KZG_RET expand_root_of_unity(fr_t *out, const fr_t *root, uint64_t width) { static C_KZG_RET expand_root_of_unity(fr_t *out, const fr_t *root, uint64_t width) {
out[0] = fr_one; out[0] = FR_ONE;
out[1] = *root; out[1] = *root;
for (uint64_t i = 2; !fr_is_one(&out[i - 1]); i++) { for (uint64_t i = 2; !fr_is_one(&out[i - 1]); i++) {

View File

@ -135,6 +135,7 @@ void bytes_from_g1(Bytes48 *out, const g1_t *in);
C_KZG_RET evaluate_polynomial_in_evaluation_form(fr_t *out, const Polynomial *p, const fr_t *x, const KZGSettings *s); C_KZG_RET evaluate_polynomial_in_evaluation_form(fr_t *out, const Polynomial *p, const fr_t *x, const KZGSettings *s);
C_KZG_RET blob_to_polynomial(Polynomial *p, const Blob *blob); C_KZG_RET blob_to_polynomial(Polynomial *p, const Blob *blob);
C_KZG_RET bytes_to_bls_field(fr_t *out, const Bytes32 *b); C_KZG_RET bytes_to_bls_field(fr_t *out, const Bytes32 *b);
uint32_t reverse_bits(uint32_t a);
#endif #endif

View File

@ -75,6 +75,12 @@ static void bytes48_from_hex(Bytes48 *out, const char *hex) {
} }
} }
static void get_rand_uint32(uint32_t *out) {
Bytes32 b;
get_rand_bytes32(&b);
*out = *(uint32_t *)(b.bytes);
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Tests for blob_to_kzg_commitment // Tests for blob_to_kzg_commitment
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -313,6 +319,39 @@ static void test_validate_kzg_g1__fails_with_b_flag_and_a_flag_true(void) {
ASSERT_EQUALS(ret, C_KZG_BADARGS); ASSERT_EQUALS(ret, C_KZG_BADARGS);
} }
///////////////////////////////////////////////////////////////////////////////
// Tests for reverse_bits
///////////////////////////////////////////////////////////////////////////////
static void test_reverse_bits__round_trip(void) {
uint32_t original;
uint32_t reversed;
uint32_t reversed_reversed;
get_rand_uint32(&original);
reversed = reverse_bits(original);
reversed_reversed = reverse_bits(reversed);
ASSERT_EQUALS(reversed_reversed, original);
}
static void test_reverse_bits__all_bits_are_zero(void) {
uint32_t original = 0b00000000000000000000000000000000;
uint32_t reversed = 0b00000000000000000000000000000000;
ASSERT_EQUALS(reverse_bits(original), reversed);
}
static void test_reverse_bits__some_bits_are_one(void) {
uint32_t original = 0b10101000011111100000000000000010;
uint32_t reversed = 0b01000000000000000111111000010101;
ASSERT_EQUALS(reverse_bits(original), reversed);
}
static void test_reverse_bits__all_bits_are_one(void) {
uint32_t original = 0b11111111111111111111111111111111;
uint32_t reversed = 0b11111111111111111111111111111111;
ASSERT_EQUALS(reverse_bits(original), reversed);
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Tests for compute_kzg_proof // Tests for compute_kzg_proof
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -402,6 +441,10 @@ int main(void) {
RUN(test_validate_kzg_g1__fails_with_wrong_c_flag); RUN(test_validate_kzg_g1__fails_with_wrong_c_flag);
RUN(test_validate_kzg_g1__fails_with_b_flag_and_x_nonzero); RUN(test_validate_kzg_g1__fails_with_b_flag_and_x_nonzero);
RUN(test_validate_kzg_g1__fails_with_b_flag_and_a_flag_true); RUN(test_validate_kzg_g1__fails_with_b_flag_and_a_flag_true);
RUN(test_reverse_bits__round_trip);
RUN(test_reverse_bits__all_bits_are_zero);
RUN(test_reverse_bits__some_bits_are_one);
RUN(test_reverse_bits__all_bits_are_one);
RUN(test_compute_and_verify_kzg_proof); RUN(test_compute_and_verify_kzg_proof);
teardown(); teardown();