diff --git a/beacon_chain_impl/bls.py b/beacon_chain_impl/bls.py index b03c1ae..a55fee8 100644 --- a/beacon_chain_impl/bls.py +++ b/beacon_chain_impl/bls.py @@ -3,8 +3,8 @@ try: except: from pyblake2 import blake2s blake = lambda x: blake2s(x).digest() -from py_ecc.optimized_bn128 import G1, G2, add, multiply, FQ, FQ2, pairing, \ - normalize, field_modulus, b, b2, is_on_curve, curve_order +from py_ecc.optimized_bn128 import G1, G2, neg, add, multiply, FQ, FQ2, FQ12, pairing, \ + normalize, field_modulus, b, b2, is_on_curve, curve_order, final_exponentiate def compress_G1(pt): x, y = normalize(pt) @@ -34,7 +34,11 @@ def sqrt_fq2(x): y *= hex_root return y +cache = {} + def hash_to_G2(m): + if m in cache: + return cache[m] k2 = m while 1: k1 = blake(k2) @@ -46,7 +50,9 @@ def hash_to_G2(m): if xcb ** ((field_modulus ** 2 - 1) // 2) == FQ2([1,0]): break y = sqrt_fq2(xcb) - return multiply((x, y, FQ2([1,0])), 2*field_modulus-curve_order) + o = multiply((x, y, FQ2([1,0])), 2*field_modulus-curve_order) + cache[m] = o + return o def compress_G2(pt): assert is_on_curve(pt, b2) @@ -73,7 +79,8 @@ def privtopub(k): return compress_G1(multiply(G1, k)) def verify(m, pub, sig): - return pairing(decompress_G2(sig), G1) == pairing(hash_to_G2(m), decompress_G1(pub)) + return final_exponentiate(pairing(decompress_G2(sig), G1, False) * \ + pairing(hash_to_G2(m), neg(decompress_G1(pub)), False)) == FQ12.one() def aggregate_sigs(sigs): o = FQ2([1,0]), FQ2([1,0]), FQ2([0,0]) diff --git a/beacon_chain_impl/test_full_pos.py b/beacon_chain_impl/test_full_pos.py index e3e889a..9e75a87 100644 --- a/beacon_chain_impl/test_full_pos.py +++ b/beacon_chain_impl/test_full_pos.py @@ -4,8 +4,9 @@ from full_pos import blake, mk_genesis_state_and_block, compute_state_transition import random import bls from simpleserialize import serialize, deserialize, eq, deepcopy +import time -privkeys = [int.from_bytes(blake(str(i).encode('utf-8'))[:4], 'big') for i in range(500)] +privkeys = [int.from_bytes(blake(str(i).encode('utf-8'))[:4], 'big') for i in range(1000)] print('Generated privkeys') keymap = {} for i,k in enumerate(privkeys): @@ -80,7 +81,9 @@ print('Crystallized state length:', len(serialize(c))) print('Active state length:', len(serialize(a))) print('Block size:', len(serialize(block))) block2, c2, a2 = mock_make_child((c, a), block, 0, 0.8, []) +t = time.time() assert compute_state_transition((c, a), block, block2) +print("Normal block (basic attestation only) processed in %.4f sec" % (time.time() - t)) print('Verified a block!') block3, c3, a3 = mock_make_child((c2, a2), block2, 0, 0.8, [(0, 0.75)]) print('Verified a block with a committee!') @@ -89,3 +92,6 @@ while a3.height % SHARD_COUNT > 0: print('Height: %d' % a3.height) print('FFG bitmask:', bin(int.from_bytes(a3.ffg_voter_bitmask, 'big'))) block4, c4, a4 = mock_make_child((c3, a3), block3, 1, 0.55, []) +t = time.time() +assert compute_state_transition((c3, a3), block3, block4) +print("Epoch transition processed in %.4f sec" % (time.time() - t))