153 lines
5.5 KiB
Python
153 lines
5.5 KiB
Python
import copy
|
|
import poly_utils
|
|
import rlp
|
|
|
|
# Keccak-256 helper used for every hash in this module.
# Prefer pycryptodome; fall back to the standalone ``sha3`` (pysha3) package.
try:
    from Crypto.Hash import keccak

    def sha3(x):
        """Return the 32-byte Keccak-256 digest of the bytes ``x``."""
        return keccak.new(digest_bits=256, data=x).digest()

except ImportError:
    import sha3 as _sha3

    # pysha3 >= 1.0 exposes the original Keccak as ``keccak_256`` and turned
    # ``sha3_256`` into the FIPS-202 variant, which produces different digests
    # from the pycryptodome branch above. Use ``keccak_256`` when it exists and
    # only fall back to ``sha3_256`` on old releases that lack it.
    _keccak_256 = getattr(_sha3, 'keccak_256', None) or _sha3.sha3_256

    def sha3(x):
        """Return the 32-byte Keccak-256 digest of the bytes ``x``."""
        return _keccak_256(x).digest()
|
|
|
|
# A point is one element of GF(2**16), i.e. it serializes to two bytes
POINT_SIZE = 2
# Number of points packed into a single chunk
POINTS_IN_CHUNK = 128
# Byte length of one chunk (128 points * 2 bytes = 256 bytes)
CHUNK_SIZE = POINTS_IN_CHUNK * POINT_SIZE
|
|
|
|
def bytes_to_num(bytez):
    """Interpret ``bytez`` as a big-endian base-256 number.

    Iterates over ``bytez`` (a bytes object or any iterable of ints) and
    returns the accumulated integer; an empty input yields 0.
    """
    value = 0
    for byte in bytez:
        # Shift what we have one base-256 digit to the left, then add the new digit
        value *= 256
        value += byte
    return value
|
|
|
|
def num_to_bytes(inp, n):
    """Serialize ``inp`` as exactly ``n`` big-endian bytes.

    Only the ``n`` least-significant base-256 digits are kept, so values that
    do not fit are silently truncated; smaller values are left-padded with
    zero bytes.
    """
    digits = []
    for _ in range(n):
        # Peel off the least-significant base-256 digit
        inp, low = divmod(inp, 256)
        digits.append(low)
    # Digits were collected least-significant first; emit most-significant first
    return bytes(reversed(digits))
|
|
|
|
# Import-time sanity check: a number written as two bytes reads back unchanged
assert bytes_to_num(num_to_bytes(31337, 2)) == 31337
|
|
|
|
# Returns the smallest power of 2 equal to or greater than a number
def higher_power_of_2(x):
    """Return the smallest power of two that is >= ``x``.

    Any ``x`` that is <= 1 (including zero and negatives) yields 1.
    """
    power = 1
    while not power >= x:
        power <<= 1
    return power
|
|
|
|
# Most padding schemes standardized in cryptography seem to only work for
# block sizes strictly less than 256 bytes, so RLP plus zero-byte padding is
# used instead. The data is RLP-encoded *before* padding because the RLP
# encoding adds length data, which keeps the padding reversible even when the
# original data itself ends in zero bytes.
def pad(data):
    """RLP-encode ``data`` and right-pad it with zero bytes to a power-of-two length."""
    encoded = rlp.encode(data)
    target_len = higher_power_of_2(len(encoded))
    return encoded + bytes(target_len - len(encoded))
|
|
|
|
def unpad(data):
    """Cut ``data`` down to the length announced by its leading RLP length prefix.

    The returned value is a prefix slice of ``data`` itself (nothing is
    RLP-decoded here).

    NOTE(review): this relies on ``rlp.codec.consume_length_prefix`` accepting
    a single argument and returning exactly three values, the first of which
    compares equal to ``str`` -- confirm against the installed pyrlp version.
    """
    prefix_info = rlp.codec.consume_length_prefix(data)
    item_type, first_len, second_len = prefix_info
    assert item_type == str
    # The two length values together give the end offset of the unpadded data
    end = first_len + second_len
    return data[:end]
|
|
|
|
# Deserialize a chunk into a list of points in GF2**16
def chunk_to_points(chunk):
    """Split ``chunk`` into POINTS_IN_CHUNK integers, POINT_SIZE bytes each.

    The walk always covers CHUNK_SIZE bytes; slices that fall past the end of
    a shorter ``chunk`` are empty and therefore decode to 0.
    """
    points = []
    for offset in range(0, CHUNK_SIZE, POINT_SIZE):
        raw_point = chunk[offset: offset + POINT_SIZE]
        points.append(bytes_to_num(raw_point))
    return points
|
|
|
|
# Serialize a list of points into a chunk
def points_to_chunk(points):
    """Concatenate the POINT_SIZE-byte big-endian encoding of every point."""
    pieces = []
    for point in points:
        pieces.append(num_to_bytes(point, POINT_SIZE))
    return b''.join(pieces)
|
|
|
|
# Sample chunk for the self-test: the 32-byte hash of b'cow' repeated until it
# fills exactly CHUNK_SIZE bytes
testdata = sha3(b'cow') * (CHUNK_SIZE // 32)
|
|
# Import-time sanity check: chunk -> points -> chunk gives back the same bytes
assert points_to_chunk(chunk_to_points(testdata)) == testdata
|
|
|
|
# Make a Merkle tree out of a set of chunks
def merklize(chunks):
    """Build a binary Merkle tree over ``chunks`` and return it as a flat list.

    Layout of the returned list (length ``2 * len(chunks)``):
      * index 0 is a 32-byte zero placeholder,
      * index 1 is the Merkle root,
      * the node at index ``i`` has its children at ``2 * i`` and ``2 * i + 1``,
      * the last ``len(chunks)`` entries are the hashes of the chunks.

    ``len(chunks)`` must be an exact power of two (checked with ``assert``).
    A single chunk is accepted: its hash is then the root.
    """
    # Only accept a list of size which is exactly a power of two
    assert higher_power_of_2(len(chunks)) == len(chunks)
    lower_tier = [sha3(x) for x in chunks]
    merkle_nodes = lower_tier[:]
    # Hash adjacent pairs tier by tier until one node (the root) is left.
    # Testing the tier being consumed (not the one just produced) lets a
    # one-leaf tree terminate immediately instead of indexing past the end.
    while len(lower_tier) > 1:
        higher_tier = [
            sha3(lower_tier[i] + lower_tier[i + 1])
            for i in range(0, len(lower_tier), 2)
        ]
        # Each new tier goes in front so parents precede their children
        merkle_nodes = higher_tier + merkle_nodes
        lower_tier = higher_tier
    # Placeholder at index 0 makes the root sit at index 1
    merkle_nodes.insert(0, b'\x00' * 32)
    return merkle_nodes
|
|
|
|
|
|
class Prover():
    """Erasure-codes a blob of data and serves Merkle proofs for its chunks.

    Attributes set by ``__init__``:
      * ``extended_data`` -- the original byte chunks followed by an equal
        number of chunks obtained by polynomial extension,
      * ``length`` -- number of chunks in ``extended_data``,
      * ``merkle_nodes`` -- flat Merkle tree over ``extended_data`` as
        returned by ``merklize``,
      * ``merkle_root`` -- ``merkle_nodes[1]``.
    """

    def __init__(self, data):
        """Pad ``data``, split it into chunks, extend them and merklize the result."""
        # Pad data
        pdata = pad(data)
        # pad() only rounds up to a power of two, so a short input can come out
        # smaller than one chunk. chunk_to_points() would read such a chunk as
        # if it were zero-filled to CHUNK_SIZE, while the Merkle leaf would be
        # the short byte string, so a tree rebuilt from points could never
        # match. Zero-fill to one full chunk to keep both views identical.
        if len(pdata) < CHUNK_SIZE:
            pdata += b'\x00' * (CHUNK_SIZE - len(pdata))
        byte_chunks = [pdata[i: i + CHUNK_SIZE] for i in range(0, len(pdata), CHUNK_SIZE)]
        # Decompose it into chunks, where each chunk is a collection of numbers
        chunks = [chunk_to_points(byte_chunk) for byte_chunk in byte_chunks]
        # Compute the polynomials representing the ith number in each chunk
        polys = [
            poly_utils.lagrange_interp([chunk[i] for chunk in chunks], list(range(len(chunks))))
            for i in range(POINTS_IN_CHUNK)
        ]
        # Use the polynomials to extend the chunks: evaluate every polynomial
        # at the x coordinates following the original ones
        new_chunks = [
            points_to_chunk([poly_utils.eval_poly_at(poly, x) for poly in polys])
            for x in range(len(chunks), len(chunks) * 2)
        ]
        self.extended_data = byte_chunks + new_chunks
        # Total length of data including new points
        self.length = len(self.extended_data)
        # Build up the Merkle tree
        self.merkle_nodes = merklize(self.extended_data)
        assert len(self.merkle_nodes) == 2 * self.length
        self.merkle_root = self.merkle_nodes[1]

    # Make a Merkle proof for some index
    def prove(self, index):
        """Return ``[chunk, sibling, sibling, ...]`` for the chunk at ``index``.

        The first element is the chunk itself; the rest are the sibling hashes
        on the path from the leaf up to (but excluding) the root.
        ``index`` must satisfy ``0 <= index < self.length`` (checked with ``assert``).
        """
        assert 0 <= index < self.length
        # Leaves start at position self.length in the flat tree
        adjusted_index = self.length + index
        o = [self.extended_data[index]]
        while adjusted_index > 1:
            # Flipping the lowest bit gives the sibling; shifting moves to the parent
            o.append(self.merkle_nodes[adjusted_index ^ 1])
            adjusted_index >>= 1
        return o
|
|
|
|
# Verify a merkle proof of some index (light client friendly)
def verify_proof(merkle_root, proof, index):
    """Check that ``proof`` links its leaf at position ``index`` to ``merkle_root``.

    ``proof`` is ``[chunk, sibling, sibling, ...]`` with siblings ordered from
    the leaf level upwards. Returns True when the recomputed root equals
    ``merkle_root`` and False otherwise, including when ``proof`` is empty or
    ``index`` does not address a leaf of a tree of that depth.
    """
    if not proof:
        return False
    # Only the lowest len(proof) - 1 bits of index steer the hashing below, so
    # without this check any index sharing those bits (or a negative one)
    # would be accepted as if it were the in-range position.
    depth = len(proof) - 1
    if not 0 <= index < (1 << depth):
        return False
    h = sha3(proof[0])
    for p in proof[1:]:
        if index % 2:
            # Current node is a right child: the sibling goes on the left
            h = sha3(p + h)
        else:
            h = sha3(h + p)
        index //= 2
    return h == merkle_root
|
|
|
|
# Fill data from partially available proofs
# This method returning False can also be used as a verifier for fraud proofs
def fill(merkle_root, orig_data_length, proofs, indices):
    """Reconstruct every chunk of the extended data from ``orig_data_length`` proofs.

    ``proofs[k]`` must be a Merkle proof (as accepted by ``verify_proof``) for
    the chunk at position ``indices[k]``.

    Returns the list of all ``2 * orig_data_length`` chunks, or False when the
    Merkle root of the reconstructed chunks differs from ``merkle_root``.

    Raises ``Exception`` when the number of proofs is not exactly
    ``orig_data_length``, when ``indices`` does not pair up one-to-one with
    ``proofs``, when an index is repeated or out of range, or when a proof
    fails verification.
    """
    if len(proofs) < orig_data_length:
        raise Exception("Not enough proofs")
    if len(proofs) > orig_data_length:
        raise Exception("Too many proofs; if original data has n chunks, n chunks suffice")
    # zip() below stops at the shorter input, so a missing index would let a
    # proof through without ever being verified
    if len(indices) != len(proofs):
        raise Exception("Number of indices does not match number of proofs")
    # Interpolation needs distinct x coordinates
    if len(set(indices)) != len(indices):
        raise Exception("Duplicate indices")
    total_chunks = orig_data_length * 2
    for index in indices:
        # An out-of-range index would be used as a wrong x coordinate below,
        # making valid data reconstruct to a mismatching root
        if not 0 <= index < total_chunks:
            raise Exception("Index %d out of range" % index)
    for proof, index in zip(proofs, indices):
        if not verify_proof(merkle_root, proof, index):
            raise Exception("Merkle proof for index %d invalid" % index)
    # Convert to points
    coords = [chunk_to_points(p[0]) for p in proofs]
    # Extract polynomials
    polys = [
        poly_utils.lagrange_interp([c[i] for c in coords], indices)
        for i in range(POINTS_IN_CHUNK)
    ]
    # Fill in the remaining values: keep the supplied chunks as they are and
    # evaluate the polynomials at every position that was not supplied
    full_coords = [None] * total_chunks
    for points, index in zip(coords, indices):
        full_coords[index] = points
    for i, points in enumerate(full_coords):
        if points is None:
            full_coords[i] = [poly_utils.eval_poly_at(poly, i) for poly in polys]
    # Serialize
    full_chunks = [points_to_chunk(points) for points in full_coords]
    # Merklize
    merkle_nodes = merklize(full_chunks)
    # Check equality of the Merkle root
    if merkle_root != merkle_nodes[1]:
        return False
    return full_chunks
|