Fix `dump_kzg_trusted_setup_files`. Use Fastest BLS lib (#3358)

This commit is contained in:
Hsiao-Wei Wang 2023-05-16 20:07:21 +08:00 committed by GitHub
parent 0e0b9ac00d
commit bb38c56ddd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 17 deletions

View File

@ -10,20 +10,13 @@ from typing import (
from pathlib import Path from pathlib import Path
from eth_utils import encode_hex from eth_utils import encode_hex
from py_ecc.optimized_bls12_381 import ( # noqa: F401
G1,
G2,
Z1,
Z2,
curve_order as BLS_MODULUS,
add,
multiply,
neg,
)
from py_ecc.typing import ( from py_ecc.typing import (
Optimized_Point3D, Optimized_Point3D,
) )
from eth2spec.utils import bls from eth2spec.utils import bls
from eth2spec.utils.bls import (
BLS_MODULUS,
)
PRIMITIVE_ROOT_OF_UNITY = 7 PRIMITIVE_ROOT_OF_UNITY = 7
@ -35,7 +28,7 @@ def generate_setup(generator: Optimized_Point3D, secret: int, length: int) -> Tu
""" """
result = [generator] result = [generator]
for _ in range(1, length): for _ in range(1, length):
result.append(multiply(result[-1], secret)) result.append(bls.multiply(result[-1], secret))
return tuple(result) return tuple(result)
@ -49,9 +42,9 @@ def fft(vals: Sequence[Optimized_Point3D], modulus: int, domain: int) -> Sequenc
R = fft(vals[1::2], modulus, domain[::2]) R = fft(vals[1::2], modulus, domain[::2])
o = [0] * len(vals) o = [0] * len(vals)
for i, (x, y) in enumerate(zip(L, R)): for i, (x, y) in enumerate(zip(L, R)):
y_times_root = multiply(y, domain[i]) y_times_root = bls.multiply(y, domain[i])
o[i] = add(x, y_times_root) o[i] = bls.add(x, y_times_root)
o[i + len(L)] = add(x, neg(y_times_root)) o[i + len(L)] = bls.add(x, bls.neg(y_times_root))
return o return o
@ -90,12 +83,14 @@ def get_lagrange(setup: Sequence[Optimized_Point3D]) -> Tuple[bytes]:
# TODO: introduce an IFFT function for simplicity # TODO: introduce an IFFT function for simplicity
fft_output = fft(setup, BLS_MODULUS, domain) fft_output = fft(setup, BLS_MODULUS, domain)
inv_length = pow(len(setup), BLS_MODULUS - 2, BLS_MODULUS) inv_length = pow(len(setup), BLS_MODULUS - 2, BLS_MODULUS)
return tuple(bls.G1_to_bytes48(multiply(fft_output[-i], inv_length)) for i in range(len(fft_output))) return tuple(bls.G1_to_bytes48(bls.multiply(fft_output[-i], inv_length)) for i in range(len(fft_output)))
def dump_kzg_trusted_setup_files(secret: int, g1_length: int, g2_length: int, output_dir: str) -> None: def dump_kzg_trusted_setup_files(secret: int, g1_length: int, g2_length: int, output_dir: str) -> None:
setup_g1 = generate_setup(bls.G1, secret, g1_length) bls.use_fastest()
setup_g2 = generate_setup(bls.G2, secret, g2_length)
setup_g1 = generate_setup(bls.G1(), secret, g1_length)
setup_g2 = generate_setup(bls.G2(), secret, g2_length)
setup_g1_lagrange = get_lagrange(setup_g1) setup_g1_lagrange = get_lagrange(setup_g1)
roots_of_unity = compute_roots_of_unity(g1_length) roots_of_unity = compute_roots_of_unity(g1_length)