diff --git a/src/num.h b/src/num.h index 596a122..534e095 100644 --- a/src/num.h +++ b/src/num.h @@ -45,6 +45,9 @@ void static secp256k1_num_mod_mul(secp256k1_num_t *r, const secp256k1_num_t *a, /** Compare the absolute value of two numbers. */ int static secp256k1_num_cmp(const secp256k1_num_t *a, const secp256k1_num_t *b); +/** Test whether two number are equal (including sign). */ +int static secp256k1_num_eq(const secp256k1_num_t *a, const secp256k1_num_t *b); + /** Add two (signed) numbers. */ void static secp256k1_num_add(secp256k1_num_t *r, const secp256k1_num_t *a, const secp256k1_num_t *b); diff --git a/src/num_gmp_impl.h b/src/num_gmp_impl.h index 75ef9dc..f97acd3 100644 --- a/src/num_gmp_impl.h +++ b/src/num_gmp_impl.h @@ -161,6 +161,13 @@ int static secp256k1_num_cmp(const secp256k1_num_t *a, const secp256k1_num_t *b) return mpn_cmp(a->data, b->data, a->limbs); } +int static secp256k1_num_eq(const secp256k1_num_t *a, const secp256k1_num_t *b) { + if (a->limbs > b->limbs) return 0; + if (a->limbs < b->limbs) return 0; + if ((a->neg && !secp256k1_num_is_zero(a)) != (b->neg && !secp256k1_num_is_zero(b))) return 0; + return mpn_cmp(a->data, b->data, a->limbs) == 0; +} + void static secp256k1_num_subadd(secp256k1_num_t *r, const secp256k1_num_t *a, const secp256k1_num_t *b, int bneg) { if (!(b->neg ^ bneg ^ a->neg)) { // a and b have the same sign r->neg = a->neg; diff --git a/src/num_openssl_impl.h b/src/num_openssl_impl.h index 948b2e7..262dabf 100644 --- a/src/num_openssl_impl.h +++ b/src/num_openssl_impl.h @@ -54,7 +54,11 @@ void static secp256k1_num_mod_mul(secp256k1_num_t *r, const secp256k1_num_t *a, } int static secp256k1_num_cmp(const secp256k1_num_t *a, const secp256k1_num_t *b) { - return BN_cmp(&a->bn, &b->bn); + return BN_ucmp(&a->bn, &b->bn); +} + +int static secp256k1_num_eq(const secp256k1_num_t *a, const secp256k1_num_t *b) { + return BN_cmp(&a->bn, &b->bn) == 0; } void static secp256k1_num_add(secp256k1_num_t *r, const secp256k1_num_t *a, const secp256k1_num_t *b) { diff --git a/src/tests.c b/src/tests.c index 3fd8f9a..5827eae 100644 --- a/src/tests.c +++ b/src/tests.c @@ -67,11 +67,11 @@ void test_num_copy_inc_cmp() { secp256k1_num_init(&n2); random_num_order(&n1); secp256k1_num_copy(&n2, &n1); - CHECK(secp256k1_num_cmp(&n1, &n2) == 0); - CHECK(secp256k1_num_cmp(&n2, &n1) == 0); + CHECK(secp256k1_num_eq(&n1, &n2)); + CHECK(secp256k1_num_eq(&n2, &n1)); secp256k1_num_inc(&n2); - CHECK(secp256k1_num_cmp(&n1, &n2) != 0); - CHECK(secp256k1_num_cmp(&n2, &n1) != 0); + CHECK(!secp256k1_num_eq(&n1, &n2)); + CHECK(!secp256k1_num_eq(&n2, &n1)); secp256k1_num_free(&n1); secp256k1_num_free(&n2); } @@ -85,7 +85,7 @@ void test_num_get_set_hex() { char c[64]; secp256k1_num_get_hex(c, 64, &n1); secp256k1_num_set_hex(&n2, c, 64); - CHECK(secp256k1_num_cmp(&n1, &n2) == 0); + CHECK(secp256k1_num_eq(&n1, &n2)); for (int i=0; i<64; i++) { // check whether the lower 4 bits correspond to the last hex character int low1 = secp256k1_num_shift(&n1, 4); @@ -96,7 +96,7 @@ void test_num_get_set_hex() { memmove(c+1, c, 63); c[0] = '0'; secp256k1_num_set_hex(&n2, c, 64); - CHECK(secp256k1_num_cmp(&n1, &n2) == 0); + CHECK(secp256k1_num_eq(&n1, &n2)); } secp256k1_num_free(&n2); secp256k1_num_free(&n1); @@ -110,7 +110,7 @@ void test_num_get_set_bin() { unsigned char c[32]; secp256k1_num_get_bin(c, 32, &n1); secp256k1_num_set_bin(&n2, c, 32); - CHECK(secp256k1_num_cmp(&n1, &n2) == 0); + CHECK(secp256k1_num_eq(&n1, &n2)); for (int i=0; i<32; i++) { // check whether the lower 8 bits correspond to the last byte int low1 = secp256k1_num_shift(&n1, 8); @@ -120,7 +120,7 @@ void test_num_get_set_bin() { memmove(c+1, c, 31); c[0] = 0; secp256k1_num_set_bin(&n2, c, 32); - CHECK(secp256k1_num_cmp(&n1, &n2) == 0); + CHECK(secp256k1_num_eq(&n1, &n2)); } secp256k1_num_free(&n2); secp256k1_num_free(&n1); @@ -159,21 +159,25 @@ void test_num_negate() { secp256k1_num_negate(&n1); // n1 = -R CHECK(secp256k1_num_is_neg(&n1) != secp256k1_num_is_neg(&n2)); secp256k1_num_negate(&n1); // n1 = R - CHECK(secp256k1_num_cmp(&n1, &n2) == 0); - CHECK(secp256k1_num_is_neg(&n1) == secp256k1_num_is_neg(&n2)); + CHECK(secp256k1_num_eq(&n1, &n2)); secp256k1_num_free(&n2); secp256k1_num_free(&n1); } void test_num_add_sub() { + int r = secp256k1_rand32(); secp256k1_num_t n1; secp256k1_num_t n2; secp256k1_num_init(&n1); secp256k1_num_init(&n2); random_num_order_test(&n1); // n1 = R1 - random_num_negate(&n1); + if (r & 1) { + random_num_negate(&n1); + } random_num_order_test(&n2); // n2 = R2 - random_num_negate(&n2); + if (r & 2) { + random_num_negate(&n2); + } secp256k1_num_t n1p2, n2p1, n1m2, n2m1; secp256k1_num_init(&n1p2); secp256k1_num_init(&n2p1); @@ -183,16 +187,16 @@ void test_num_add_sub() { secp256k1_num_add(&n2p1, &n2, &n1); // n2p1 = R2 + R1 secp256k1_num_sub(&n1m2, &n1, &n2); // n1m2 = R1 - R2 secp256k1_num_sub(&n2m1, &n2, &n1); // n2m1 = R2 - R1 - CHECK(secp256k1_num_cmp(&n1p2, &n2p1) == 0); - CHECK(secp256k1_num_cmp(&n1p2, &n1m2) != 0); + CHECK(secp256k1_num_eq(&n1p2, &n2p1)); + CHECK(!secp256k1_num_eq(&n1p2, &n1m2)); secp256k1_num_negate(&n2m1); // n2m1 = -R2 + R1 - CHECK(secp256k1_num_cmp(&n2m1, &n1m2) == 0); - CHECK(secp256k1_num_cmp(&n2m1, &n1) != 0); + CHECK(secp256k1_num_eq(&n2m1, &n1m2)); + CHECK(!secp256k1_num_eq(&n2m1, &n1)); secp256k1_num_add(&n2m1, &n2m1, &n2); // n2m1 = -R2 + R1 + R2 = R1 - CHECK(secp256k1_num_cmp(&n2m1, &n1) == 0); - CHECK(secp256k1_num_cmp(&n2p1, &n1) != 0); + CHECK(secp256k1_num_eq(&n2m1, &n1)); + CHECK(!secp256k1_num_eq(&n2p1, &n1)); secp256k1_num_sub(&n2p1, &n2p1, &n2); // n2p1 = R2 + R1 - R2 = R1 - CHECK(secp256k1_num_cmp(&n2p1, &n1) == 0); + CHECK(secp256k1_num_eq(&n2p1, &n1)); secp256k1_num_free(&n2m1); secp256k1_num_free(&n1m2); secp256k1_num_free(&n2p1); @@ -492,7 +496,7 @@ void test_wnaf(const secp256k1_num_t *number, int w) { secp256k1_num_set_int(&t, v); secp256k1_num_add(&x, &x, &t); } - CHECK(secp256k1_num_cmp(&x, number) == 0); // check that wnaf represents number + CHECK(secp256k1_num_eq(&x, number)); // check that wnaf represents number secp256k1_num_free(&x); secp256k1_num_free(&two); secp256k1_num_free(&t);