153 lines
5.5 KiB
Python

import copy
import poly_utils
import rlp
# Select a Keccak-256 implementation. Both backends must produce the same digest,
# otherwise Merkle roots would differ between machines.
try:
    from Crypto.Hash import keccak

    def sha3(x):
        """Return the 32-byte Keccak-256 digest of the bytes ``x``."""
        return keccak.new(digest_bits=256, data=x).digest()
except ImportError:
    import sha3 as _sha3

    # pysha3 >= 1.0 redefined ``sha3_256`` as FIPS-202 SHA3-256 and exposes the
    # original Keccak as ``keccak_256``; older releases only have ``sha3_256``
    # (which was Keccak there). Prefer ``keccak_256`` so the digest matches the
    # pycryptodome branch above.
    _keccak_256 = getattr(_sha3, 'keccak_256', None) or _sha3.sha3_256

    def sha3(x):
        """Return the 32-byte Keccak-256 digest of the bytes ``x``."""
        return _keccak_256(x).digest()
# Points are elements of GF(2**16), so each one serializes to two bytes
POINT_SIZE = 2
# Number of points packed into a single chunk
POINTS_IN_CHUNK = 128
# Serialized size of a chunk in bytes (128 points * 2 bytes = 256 bytes)
CHUNK_SIZE = POINTS_IN_CHUNK * POINT_SIZE
def bytes_to_num(bytez):
    """Interpret ``bytez`` as a big-endian unsigned integer.

    :param bytez: a bytes-like object (may be empty)
    :returns: the decoded non-negative integer; ``0`` for empty input
    """
    return int.from_bytes(bytez, 'big')
def num_to_bytes(inp, n):
    """Serialize the integer ``inp`` as exactly ``n`` big-endian bytes.

    Only the low ``8 * n`` bits of ``inp`` are kept, so values that do not fit
    are truncated (and negative values wrap) rather than raising.

    :param inp: the integer to serialize
    :param n: the number of output bytes; ``n <= 0`` yields ``b''``
    :returns: a ``bytes`` object of length ``max(n, 0)``
    """
    if n <= 0:
        return b''
    # Reduce modulo 256**n first: int.to_bytes would raise OverflowError on
    # values that do not fit, whereas callers rely on silent truncation.
    return (inp % (1 << (8 * n))).to_bytes(n, 'big')
# Import-time sanity check: a value must survive a serialize/deserialize round trip
assert bytes_to_num(num_to_bytes(31337, 2)) == 31337
# Returns the smallest power of 2 equal to or greater than a number
def higher_power_of_2(x):
    """Return the smallest power of two that is >= ``x`` (``1`` when ``x <= 1``)."""
    power = 1
    while power < x:
        power += power
    return power
# Unfortunately, most padding schemes standardized in cryptography seem to only work for
# block sizes strictly less than 256 bytes. So we'll use RLP plus zero byte padding
# instead (pre-RLP-encode because the RLP encoding adds length data, so the padding
# becomes reversible even in cases where the original data ends in zero bytes)
def pad(data):
    """RLP-encode ``data`` and zero-pad the encoding up to a power-of-two length.

    :param data: the payload handed to ``rlp.encode``
    :returns: the RLP encoding followed by as many ``\\x00`` bytes as are needed
        to reach ``higher_power_of_2`` of the encoded length
    """
    encoded = rlp.encode(data)
    target_length = higher_power_of_2(len(encoded))
    filler = b'\x00' * (target_length - len(encoded))
    return encoded + filler
def unpad(data):
    """Cut the trailing zero padding off ``data`` using its RLP length prefix.

    The result is a prefix of ``data`` whose length is the sum of the two
    length values reported by ``rlp.codec.consume_length_prefix``.

    NOTE(review): the arity and meaning of the tuple returned by
    ``consume_length_prefix`` differ between rlp releases - confirm against the
    pinned version.

    :param data: padded bytes whose start is an RLP length prefix
    :returns: the leading slice of ``data`` selected by the length prefix
    :raises AssertionError: if the reported item type is not ``str``
    """
    item_type, first_length, second_length = rlp.codec.consume_length_prefix(data)
    assert item_type == str
    end = first_length + second_length
    return data[:end]
# Deserialize a chunk into a list of points in GF2**16
def chunk_to_points(chunk):
    """Split ``chunk`` into POINT_SIZE-byte slices and decode each as an integer.

    Always returns CHUNK_SIZE // POINT_SIZE points; slices that fall past the
    end of a short ``chunk`` are empty and therefore decode through
    ``bytes_to_num`` like any other slice.

    :param chunk: the serialized chunk (bytes)
    :returns: a list of integers, one per point
    """
    points = []
    for offset in range(0, CHUNK_SIZE, POINT_SIZE):
        raw_point = chunk[offset: offset + POINT_SIZE]
        points.append(bytes_to_num(raw_point))
    return points
# Serialize a list of points into a chunk
def points_to_chunk(points):
    """Concatenate the POINT_SIZE-byte encoding of every point in ``points``.

    :param points: an iterable of integers
    :returns: the serialized chunk as ``bytes``
    """
    encoded_points = []
    for point in points:
        encoded_points.append(num_to_bytes(point, POINT_SIZE))
    return b''.join(encoded_points)
# Import-time sanity check: build one full chunk by repeating a 32-byte digest
# CHUNK_SIZE // 32 times, and make sure it survives a deserialize/serialize round trip
testdata = sha3(b'cow') * (CHUNK_SIZE // 32)
assert points_to_chunk(chunk_to_points(testdata)) == testdata
# Make a Merkle tree out of a set of chunks
def merklize(chunks):
    """Build a binary Merkle tree over ``chunks`` and return it as a flat list.

    Layout of the returned list for ``n`` chunks (length ``2 * n``):
    index 0 is a 32-byte zero placeholder, index 1 is the root, the children
    of node ``k`` are at ``2 * k`` and ``2 * k + 1``, and the hash of
    ``chunks[i]`` is at index ``n + i``.

    :param chunks: a list of byte strings whose length is a power of two
    :returns: the list of node hashes described above
    :raises AssertionError: if ``len(chunks)`` is not a power of two
    """
    # Only accept a list of size which is exactly a power of two
    assert higher_power_of_2(len(chunks)) == len(chunks)
    tier = [sha3(x) for x in chunks]
    merkle_nodes = list(tier)
    # Reduce until the current tier is the root. Testing the tier being reduced
    # (rather than the tier just produced) makes a single-chunk input return
    # its leaf hash as the root instead of indexing past the end of the tier.
    while len(tier) > 1:
        tier = [sha3(tier[i] + tier[i + 1]) for i in range(0, len(tier), 2)]
        merkle_nodes = tier + merkle_nodes
    # Placeholder at index 0 so that the root sits at index 1
    merkle_nodes.insert(0, b'\x00' * 32)
    return merkle_nodes
class Prover():
    """Erasure-extends a payload and serves Merkle proofs for its chunks.

    Attributes set by the constructor:
      length        -- number of chunks after extension (twice the original)
      extended_data -- the original chunks followed by the extension chunks
      merkle_nodes  -- flat Merkle tree over ``extended_data`` (see ``merklize``)
      merkle_root   -- ``merkle_nodes[1]``
    """

    def __init__(self, data):
        """Pad ``data``, split it into chunks, extend it to twice as many
        chunks by polynomial evaluation, and merklize the result.

        :param data: the payload passed to ``pad``
        """
        # Pad data
        pdata = pad(data)
        # pad() only rounds up to a power of two, so a payload shorter than one
        # chunk would give a single short chunk. Zero-fill every chunk to
        # CHUNK_SIZE so the bytes that get merklized here are the same bytes
        # fill() produces when it re-serializes the chunk's points; otherwise
        # recovery of short payloads always reports a root mismatch.
        byte_chunks = [
            pdata[i: i + CHUNK_SIZE].ljust(CHUNK_SIZE, b'\x00')
            for i in range(0, len(pdata), CHUNK_SIZE)
        ]
        # Decompose it into chunks, where each chunk is a collection of numbers
        chunks = [chunk_to_points(byte_chunk) for byte_chunk in byte_chunks]
        # Compute the polynomials representing the ith number in each chunk
        polys = [
            poly_utils.lagrange_interp([chunk[i] for chunk in chunks], list(range(len(chunks))))
            for i in range(POINTS_IN_CHUNK)
        ]
        # Use the polynomials to extend the chunks
        new_chunks = []
        for x in range(len(chunks), len(chunks) * 2):
            new_chunks.append(points_to_chunk([poly_utils.eval_poly_at(poly, x) for poly in polys]))
        # Total length of data including new points
        self.extended_data = byte_chunks + new_chunks
        self.length = len(self.extended_data)
        # Build up the Merkle tree
        self.merkle_nodes = merklize(self.extended_data)
        assert len(self.merkle_nodes) == 2 * self.length
        self.merkle_root = self.merkle_nodes[1]

    # Make a Merkle proof for some index
    def prove(self, index):
        """Return a Merkle proof for chunk ``index``.

        :param index: position in ``extended_data``, ``0 <= index < length``
        :returns: a list whose first element is the chunk itself, followed by
            the sibling hashes on the path from the leaf up to (not including)
            the root
        :raises AssertionError: if ``index`` is out of range
        """
        assert 0 <= index < self.length
        # Leaves start at offset ``length`` in the flat tree
        adjusted_index = self.length + index
        o = [self.extended_data[index]]
        while adjusted_index > 1:
            # ``^ 1`` flips the low bit, selecting the sibling node
            o.append(self.merkle_nodes[adjusted_index ^ 1])
            adjusted_index >>= 1
        return o
# Verify a merkle proof of some index (light client friendly)
def verify_proof(merkle_root, proof, index):
    """Check that ``proof`` links a chunk at position ``index`` to ``merkle_root``.

    :param merkle_root: the expected root hash
    :param proof: a list whose first element is the chunk and whose remaining
        elements are sibling hashes ordered from the leaf level upwards
    :param index: the claimed position of the chunk among the leaves
    :returns: ``True`` if the recomputed root equals ``merkle_root`` and
        ``index`` is a valid position for a tree of the proof's depth;
        ``False`` otherwise (including for an empty proof)
    """
    if not proof:
        return False
    h = sha3(proof[0])
    for p in proof[1:]:
        # The low bit says whether the current node is a right or left child
        if index % 2:
            h = sha3(p + h)
        else:
            h = sha3(h + p)
        index //= 2
    # After consuming every sibling the index must have shrunk to 0. Without
    # this check only the low bits are examined, so an out-of-range (or
    # negative) index that aliases a real leaf would be accepted.
    return index == 0 and h == merkle_root
# Fill data from partially available proofs
# This method returning False can also be used as a verifier for fraud proofs
def fill(merkle_root, orig_data_length, proofs, indices):
    """Reconstruct every chunk of the extended data from a subset of proofs.

    :param merkle_root: the root the reconstructed data must hash to
    :param orig_data_length: number of chunks before extension
    :param proofs: exactly ``orig_data_length`` Merkle proofs, each with the
        chunk as its first element
    :param indices: the position of each proof's chunk, in the same order as
        ``proofs``; must be distinct and within ``[0, 2 * orig_data_length)``
    :returns: the list of all ``2 * orig_data_length`` serialized chunks, or
        ``False`` if their Merkle root does not equal ``merkle_root``
    :raises Exception: on a wrong number of proofs or indices, duplicate or
        out-of-range indices, or a proof that fails verification
    """
    if len(proofs) < orig_data_length:
        raise Exception("Not enough proofs")
    if len(proofs) > orig_data_length:
        raise Exception("Too many proofs; if original data has n chunks, n chunks suffice")
    # zip() below would silently truncate to the shorter sequence, which could
    # leave some proofs unverified, so insist on a one-to-one pairing.
    if len(indices) != len(proofs):
        raise Exception("Number of indices does not match number of proofs")
    # Repeated x coordinates make the interpolation ill-defined
    if len(set(indices)) != len(indices):
        raise Exception("Duplicate indices")
    extended_length = orig_data_length * 2
    for proof, index in zip(proofs, indices):
        if not 0 <= index < extended_length:
            raise Exception("Index %d out of range" % index)
        if not verify_proof(merkle_root, proof, index):
            raise Exception("Merkle proof for index %d invalid" % index)
    # Convert to points
    coords = [chunk_to_points(p[0]) for p in proofs]
    # Extract polynomials
    polys = [poly_utils.lagrange_interp([c[i] for c in coords], indices) for i in range(POINTS_IN_CHUNK)]
    # Fill in the remaining values
    full_coords = [None] * extended_length
    for points, index in zip(coords, indices):
        full_coords[index] = points
    for i, points in enumerate(full_coords):
        if points is None:
            full_coords[i] = [poly_utils.eval_poly_at(poly, i) for poly in polys]
    # Serialize
    full_chunks = [points_to_chunk(points) for points in full_coords]
    # Merklize
    merkle_nodes = merklize(full_chunks)
    # Check equality of the Merkle root
    if merkle_root != merkle_nodes[1]:
        return False
    return full_chunks