Merge pull request #3714 from kevaundray/kw/use-optimized-bls-msm

chore: use py-arkworks's multi-exp method inside of `g1_lincomb` and `g2_lincomb`
This commit is contained in:
Alex Stokes 2024-04-23 11:57:13 -06:00 committed by GitHub
commit b13e03e671
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 73 additions and 8 deletions

View File

@ -17,6 +17,7 @@ from eth2spec.capella import {preset_name} as capella
def preparations(cls):
return '''
T = TypeVar('T') # For generic function
TPoint = TypeVar('TPoint') # For generic function. G1 or G2 point.
'''
@classmethod

View File

@ -130,12 +130,18 @@ def coset_evals_to_cell(coset_evals: CosetEvals) -> Cell:
```python
def g2_lincomb(points: Sequence[G2Point], scalars: Sequence[BLSFieldElement]) -> Bytes96:
"""
BLS multiscalar multiplication in G2. This function can be optimized using Pippenger's algorithm and variants.
BLS multiscalar multiplication in G2. This can be naively implemented using double-and-add.
"""
assert len(points) == len(scalars)
result = bls.Z2()
for x, a in zip(points, scalars):
result = bls.add(result, bls.multiply(bls.bytes96_to_G2(x), a))
if len(points) == 0:
return bls.G2_to_bytes96(bls.Z2())
points_g2 = []
for point in points:
points_g2.append(bls.bytes96_to_G2(point))
result = bls.multi_exp(points_g2, scalars)
return Bytes96(bls.G2_to_bytes96(result))
```

View File

@ -18,6 +18,7 @@
- [`reverse_bits`](#reverse_bits)
- [`bit_reversal_permutation`](#bit_reversal_permutation)
- [BLS12-381 helpers](#bls12-381-helpers)
- [`multi_exp`](#multi_exp)
- [`hash_to_bls_field`](#hash_to_bls_field)
- [`bytes_to_bls_field`](#bytes_to_bls_field)
- [`bls_field_to_bytes`](#bls_field_to_bytes)
@ -146,6 +147,18 @@ def bit_reversal_permutation(sequence: Sequence[T]) -> Sequence[T]:
### BLS12-381 helpers
#### `multi_exp`
This function performs a multi-scalar multiplication between `points` and `integers`. `points` can either be in G1 or G2.
```python
def multi_exp(points: Sequence[TPoint],
integers: Sequence[uint64]) -> Sequence[TPoint]:
# pylint: disable=unused-argument
...
```
#### `hash_to_bls_field`
```python
@ -274,12 +287,18 @@ def div(x: BLSFieldElement, y: BLSFieldElement) -> BLSFieldElement:
```python
def g1_lincomb(points: Sequence[KZGCommitment], scalars: Sequence[BLSFieldElement]) -> KZGCommitment:
"""
BLS multiscalar multiplication. This function can be optimized using Pippenger's algorithm and variants.
BLS multiscalar multiplication in G1. This can be naively implemented using double-and-add.
"""
assert len(points) == len(scalars)
result = bls.Z1()
for x, a in zip(points, scalars):
result = bls.add(result, bls.multiply(bls.bytes48_to_G1(x), a))
if len(points) == 0:
return bls.G1_to_bytes48(bls.Z1())
points_g1 = []
for point in points:
points_g1.append(bls.bytes48_to_G1(point))
result = bls.multi_exp(points_g1, scalars)
return KZGCommitment(bls.G1_to_bytes48(result))
```

View File

@ -225,6 +225,45 @@ def multiply(point, scalar):
return py_ecc_mul(point, scalar)
def multi_exp(points, integers):
"""
Performs a multi-scalar multiplication between
`points` and `integers`.
`points` can either be in G1 or G2.
"""
# Since this method accepts either G1 or G2, we need to know
# the type of the point to return. Hence, we need at least one point.
if not points or not integers:
raise Exception("Cannot call multi_exp with zero points or zero integers")
if bls == arkworks_bls or bls == fastest_bls:
# Convert integers into arkworks Scalars
scalars = []
for integer in integers:
int_as_bytes = integer.to_bytes(32, 'little')
scalars.append(arkworks_Scalar.from_le_bytes(int_as_bytes))
# Check if we need to perform a G1 or G2 multiexp
if isinstance(points[0], arkworks_G1):
return arkworks_G1.multiexp_unchecked(points, scalars)
elif isinstance(points[0], arkworks_G2):
return arkworks_G2.multiexp_unchecked(points, scalars)
else:
raise Exception("Invalid point type")
result = None
if isinstance(points[0], py_ecc_G1):
result = Z1()
elif isinstance(points[0], py_ecc_G2):
result = Z2()
else:
raise Exception("Invalid point type")
for point, scalar in points.zip(integers):
result = add(result, multiply(point, scalar))
return result
def neg(point):
"""
Returns the point negation of `point`