From 9d2443162b3c00e4dcc1c82edd8a55c61dbdbb0a Mon Sep 17 00:00:00 2001 From: Vitalik Buterin Date: Mon, 30 Apr 2018 14:23:19 -0400 Subject: [PATCH] Added sharding support --- beacon/beacon_chain_node.py | 228 +++++++++++++++++++++++------------- beacon/test.py | 46 +++----- 2 files changed, 164 insertions(+), 110 deletions(-) diff --git a/beacon/beacon_chain_node.py b/beacon/beacon_chain_node.py index 79924a4..90487f0 100644 --- a/beacon/beacon_chain_node.py +++ b/beacon/beacon_chain_node.py @@ -25,6 +25,7 @@ SKIP_TS_DIFF = 6 SAMPLE = 9 MIN_SAMPLE = 5 POWDIFF = 30 * NOTARIES +SHARDS = 4 def checkpow(work, nonce): # Discrete log PoW, lolz @@ -53,7 +54,7 @@ class BeaconBlock(): self.ts = ts self.sigs = sigs self.number = parent.number + 1 if parent else 0 - self.main_chain_ref = main_chain_ref if main_chain_ref else parent.main_chain_ref + self.main_chain_ref = main_chain_ref.hash if main_chain_ref else parent.main_chain_ref if parent: i = parent.child_proposers.index(proposer) @@ -78,6 +79,13 @@ class BeaconBlock(): self.notaries.append(v % NOTARIES) v //= NOTARIES + # Calculate shard proposers + v = hash_to_int(sha3(self.contents + b':s')) + self.shard_proposers = [] + for i in range(SHARDS): + self.shard_proposers.append(v % NOTARIES) + v //= NOTARIES + class Sig(): def __init__(self, proposer, target): @@ -86,8 +94,26 @@ class Sig(): self.hash = os.urandom(32) assert self.proposer in target.notaries +class ShardCollation(): + def __init__(self, shard_id, parent, proposer, beacon_ref, ts): + self.proposer = proposer + self.parent_hash = parent.hash if parent else (bytes([40 + shard_id]) * 32) + self.hash = sha3(self.parent_hash + str(self.proposer).encode('utf-8') + beacon_ref.hash) + self.ts = ts + self.shard_id = shard_id + self.number = parent.number + 1 if parent else 0 + self.beacon_ref = beacon_ref.hash + + if parent: + assert self.shard_id == parent.shard_id + assert self.proposer == beacon_ref.shard_proposers[self.shard_id] + assert self.ts >= parent.ts + + assert self.ts >= beacon_ref.ts + main_genesis = MainChainBlock(None, 59049, 0) -beacon_genesis = BeaconBlock(None, 1, 0, [], main_genesis.hash) +beacon_genesis = BeaconBlock(None, 1, 0, [], main_genesis) +shard_geneses = [ShardCollation(i, None, 0, beacon_genesis, 0) for i in range(SHARDS)] class BlockMakingRequest(): def __init__(self, parent, ts): @@ -98,10 +124,16 @@ class BlockMakingRequest(): class Node(): def __init__(self, _id, network, sleepy=False, careless=False): - self.blocks = {beacon_genesis.hash: beacon_genesis, main_genesis.hash: main_genesis} + self.blocks = { + beacon_genesis.hash: beacon_genesis, + main_genesis.hash: main_genesis + } + for s in shard_geneses: + self.blocks[s.hash] = s self.sigs = {} - self.beacon_head = beacon_genesis.hash + self.beacon_chain = [beacon_genesis.hash] self.main_chain = [main_genesis.hash] + self.shard_chains = [[g.hash] for g in shard_geneses] self.timequeue = [] self.parentqueue = {} self.children = {} @@ -131,13 +163,15 @@ class Node(): self.processed[obj.hash] = True #self.log("Processing %s %s" % ("block" if isinstance(obj, BeaconBlock) else "sig", to_hex(obj.hash[:4]))) if isinstance(obj, BeaconBlock): - return self.on_receive_block(obj) + return self.on_receive_beacon_block(obj) elif isinstance(obj, MainChainBlock): return self.on_receive_main_block(obj) + elif isinstance(obj, ShardCollation): + return self.on_receive_shard_collation(obj) elif isinstance(obj, Sig): return self.on_receive_sig(obj) elif isinstance(obj, BlockMakingRequest): - if self.beacon_head == obj.parent: + if self.beacon_chain[-1] == obj.parent: mc_ref = self.blocks[obj.parent] for i in range(2): if mc_ref.number == 0: @@ -145,7 +179,7 @@ class Node(): #mc_ref = self.blocks[mc_ref].parent_hash x = BeaconBlock(self.blocks[obj.parent], self.id, self.ts, self.sigs[obj.parent] if obj.parent in self.sigs else [], - self.main_chain[-1]) + self.blocks[self.main_chain[-1]]) self.log("Broadcasting block %s" % to_hex(x.hash[:4])) self.broadcast(x) @@ -155,54 +189,65 @@ class Node(): i += 1 self.timequeue.insert(i, obj) + def add_to_multiset(self, _set, k, v): + if k not in _set: + _set[k] = [] + _set[k].append(v) + + def change_head(self, chain, new_head): + chain.extend([None] * (new_head.number + 1 - len(chain))) + i, c = new_head.number, new_head.hash + while c != chain[i]: + chain[i] = c + c = self.blocks[c].parent_hash + i -= 1 + for i in range(len(chain)): + assert self.blocks[chain[i]].number == i + + def recalculate_head(self, chain, condition): + while not condition(self.blocks[chain[-1]]): + chain.pop() + descendant_queue = [chain[-1]] + new_head = chain[-1] + while len(descendant_queue): + first = descendant_queue.pop(0) + if first in self.children: + for c in self.children[first]: + if condition(self.blocks[c]): + descendant_queue.append(c) + if self.blocks[first].number > self.blocks[new_head].number: + new_head = first + self.change_head(chain, self.blocks[new_head]) + for i in range(len(chain)): + assert condition(self.blocks[chain[i]]) + + def process_children(self, h): + if h in self.parentqueue: + for b in self.parentqueue[h]: + self.on_receive(b, reprocess=True) + del self.parentqueue[h] + def on_receive_main_block(self, block): # Parent not yet received if block.parent_hash not in self.blocks: - if block.parent_hash not in self.parentqueue: - self.parentqueue[block.parent_hash] = [] - self.parentqueue[block.parent_hash].append(block) + self.add_to_multiset(self.parentqueue, block.parent_hash, block) return None self.log("Processing main chain block %s" % to_hex(block.hash[:4])) self.blocks[block.hash] = block # Reorg the main chain if new head if block.number > self.blocks[self.main_chain[-1]].number: - assert block.number == len(self.main_chain), (block.number, self.blocks[self.main_chain[-1]].number) reorging = (block.parent_hash != self.main_chain[-1]) + self.change_head(self.main_chain, block) if reorging: - self.log("Reorging main chain", all=True) - i, c = block.number - 1, block.parent_hash - while c != self.main_chain[i]: - self.main_chain[i] = c - c = self.blocks[c].parent_hash - i -= 1 - self.main_chain.append(block.hash) - for i in range(len(self.main_chain)): - assert self.blocks[self.main_chain[i]].number == i - # Reorg the beacon - if reorging: - pre_beacon = self.beacon_head - while self.blocks[self.beacon_head].main_chain_ref not in self.main_chain: - self.beacon_head = self.blocks[self.beacon_head].parent_hash - descendant_queue = [self.beacon_head] - while len(descendant_queue): - first = descendant_queue.pop(0) - if first in self.children: - for c in self.children[first]: - if isinstance(self.blocks[c], BeaconBlock) and self.blocks[c].main_chain_ref in self.main_chain: - descendant_queue.append(c) - if self.blocks[first].number > self.blocks[self.beacon_head].number: - self.beacon_head = first - if self.beacon_head != pre_beacon: - self.log("Reorged beacon due to main chain reorg", all=True) + self.recalculate_head(self.beacon_chain, + lambda b: isinstance(b, BeaconBlock) and b.main_chain_ref in self.main_chain) + for i in range(SHARDS): + self.recalculate_head(self.shard_chains[i], + lambda b: isinstance(b, ShardCollation) and b.shard_id == i and b.beacon_ref in self.beacon_chain) # Add child record - if block.parent_hash not in self.children: - self.children[block.parent_hash] = [] - self.children[block.parent_hash].append(block.hash) - # Check for children - if block.hash in self.parentqueue: - for b in self.parentqueue[block.hash]: - self.on_receive(b, reprocess=True) - del self.parentqueue[block.hash] + self.add_to_multiset(self.children, block.parent_hash, block.hash) + # Final steps + self.process_children(block.hash) self.network.broadcast(self, block) def is_descendant(self, a, b): @@ -211,70 +256,67 @@ class Node(): b = self.blocks[b.parent_hash] return a.hash == b.hash - def on_receive_block(self, block): + def change_beacon_head(self, new_head): + self.log("Changed beacon head: %s" % new_head.number) + reorging = (new_head.parent_hash != self.beacon_chain[-1]) + self.change_head(self.beacon_chain, new_head) + if reorging: + for i in range(SHARDS): + self.recalculate_head(self.shard_chains[i], + lambda b: isinstance(b, ShardCollation) and b.shard_id == i and b.beacon_ref in self.beacon_chain) + # Produce shard collations? + for s in range(SHARDS): + if self.id == new_head.shard_proposers[s]: + sc = ShardCollation(s, self.blocks[self.shard_chains[s][-1]], self.id, new_head, self.ts) + assert sc.beacon_ref == new_head.hash + assert self.is_descendant(self.blocks[sc.parent_hash].beacon_ref, new_head.hash) + self.broadcast(sc) + for c in self.shard_chains[s]: + assert self.blocks[c].shard_id == s and self.blocks[c].beacon_ref in self.beacon_chain + + def on_receive_beacon_block(self, block): # Parent not yet received if block.parent_hash not in self.blocks: - if block.parent_hash not in self.parentqueue: - self.parentqueue[block.parent_hash] = [] - self.parentqueue[block.parent_hash].append(block) - return None + self.add_to_multiset(self.parentqueue, block.parent_hash, block) + return # Main chain parent not yet received if block.main_chain_ref not in self.blocks: - if block.main_chain_ref not in self.parentqueue: - self.parentqueue[block.main_chain_ref] = [] - self.parentqueue[block.main_chain_ref].append(block) - return None + self.add_to_multiset(self.parentqueue, block.main_chain_ref, block) + return # Too early if block.ts > self.ts: self.add_to_timequeue(block) return - assert block.parent_hash in self.blocks - assert block.main_chain_ref in self.blocks - assert self.blocks[block.parent_hash].main_chain_ref in self.blocks # Check consistency of cross-link reference assert self.is_descendant(self.blocks[block.parent_hash].main_chain_ref, block.main_chain_ref) # Add the block self.log("Processing beacon block %s" % to_hex(block.hash[:4])) self.blocks[block.hash] = block - # Am I a notary, and is the block building on the head? - # careless = I notarize even stuff not on the head - if block.parent_hash == self.beacon_head or self.careless: + # Am I a notary, and is the block building on the head? Then broadcast a signature. + if block.parent_hash == self.beacon_chain[-1] or self.careless: if self.id in block.notaries: - # Then broadcast a signature self.broadcast(Sig(self.id, block)) # Check for sigs, add to head?, make a block? - if (block.hash in self.sigs and len(self.sigs[block.hash]) >= block.notary_req) or block.notary_req == 0: - if block.number > self.blocks[self.beacon_head].number and block.main_chain_ref in self.main_chain: - self.log("Changed head: %s" % block.number) - self.beacon_head = block.hash + if len(self.sigs.get(block.hash, [])) >= block.notary_req: + if block.number > self.blocks[self.beacon_chain[-1]].number and block.main_chain_ref in self.main_chain: + self.change_beacon_head(block) if self.id in self.blocks[block.hash].child_proposers: my_index = self.blocks[block.hash].child_proposers.index(self.id) target_ts = block.ts + BASE_TS_DIFF + my_index * SKIP_TS_DIFF - self.log("Making block request for %.1f" % target_ts) self.add_to_timequeue(BlockMakingRequest(block.hash, target_ts)) # Add child record - if block.parent_hash not in self.children: - self.children[block.parent_hash] = [] - self.children[block.parent_hash].append(block.hash) - # Check for children - if block.hash in self.parentqueue: - for b in self.parentqueue[block.hash]: - self.on_receive(b, reprocess=True) - del self.parentqueue[block.hash] - # Rebroadcast + self.add_to_multiset(self.children, block.parent_hash, block.hash) + # Final steps + self.process_children(block.hash) self.network.broadcast(self, block) def on_receive_sig(self, sig): - self.log("Processing sig for %s" % to_hex(sig.target_hash[:4]), lvl=1) - if sig.target_hash not in self.sigs: - self.sigs[sig.target_hash] = [] - self.sigs[sig.target_hash].append(sig) + self.add_to_multiset(self.sigs, sig.target_hash, sig) # Add to head? Make a block? if sig.target_hash in self.blocks and len(self.sigs[sig.target_hash]) == self.blocks[sig.target_hash].notary_req: block = self.blocks[sig.target_hash] - if block.number > self.blocks[self.beacon_head].number and block.main_chain_ref in self.main_chain: - self.log("Changed head: %s" % block.number) - self.beacon_head = block.hash + if block.number > self.blocks[self.beacon_chain[-1]].number and block.main_chain_ref in self.main_chain: + self.change_beacon_head(block) if self.id in block.child_proposers: my_index = block.child_proposers.index(self.id) target_ts = block.ts + BASE_TS_DIFF + my_index * SKIP_TS_DIFF @@ -283,14 +325,38 @@ class Node(): # Rebroadcast self.network.broadcast(self, sig) + def on_receive_shard_collation(self, block): + # Parent not yet received + if block.parent_hash not in self.blocks: + self.add_to_multiset(self.parentqueue, block.parent_hash, block) + return None + # Beacon ref not yet received + if block.beacon_ref not in self.blocks: + self.add_to_multiset(self.parentqueue, block.beacon_ref, block) + return None + # Check consistency of cross-link reference + assert self.is_descendant(self.blocks[block.parent_hash].beacon_ref, block.beacon_ref) + self.log("Processing shard collation %s" % to_hex(block.hash[:4])) + self.blocks[block.hash] = block + # Set head if needed + if block.number > self.blocks[self.shard_chains[block.shard_id][-1]].number and block.beacon_ref in self.beacon_chain: + self.change_head(self.shard_chains[block.shard_id], block) + # Add child record + self.add_to_multiset(self.children, block.parent_hash, block.hash) + # Final steps + self.process_children(block.hash) + self.network.broadcast(self, block) + def tick(self): if self.ts == 0: if self.id in beacon_genesis.notaries: self.broadcast(Sig(self.id, beacon_genesis)) self.ts += 0.1 self.log("Tick: %.1f" % self.ts, lvl=1) + # Process time queue while len(self.timequeue) and self.timequeue[0].ts <= self.ts: self.on_receive(self.timequeue.pop(0)) + # Attempt to mine a main chain block pownonce = random.randrange(65537) mchead = self.blocks[self.main_chain[-1]] if checkpow(mchead.pownonce, pownonce): diff --git a/beacon/test.py b/beacon/test.py index cb112b3..abf7d5d 100644 --- a/beacon/test.py +++ b/beacon/test.py @@ -1,18 +1,20 @@ from networksim import NetworkSimulator -from beacon_chain_node import Node, NOTARIES, BeaconBlock, MainChainBlock, main_genesis, beacon_genesis +from beacon_chain_node import Node, NOTARIES, SHARDS, BeaconBlock, MainChainBlock, ShardCollation, main_genesis, beacon_genesis -net = NetworkSimulator(latency=15) +net = NetworkSimulator(latency=19) notaries = [Node(i, net, sleepy=i % 5 == 9) for i in range(NOTARIES)] net.agents = notaries net.generate_peers() for i in range(2000): net.tick() for n in notaries: - print("Beacon head: %d" % n.blocks[n.beacon_head].number) + print("Beacon head: %d" % n.blocks[n.beacon_chain[-1]].number) print("Main chain head: %d" % n.blocks[n.main_chain[-1]].number) + print("Shard heads: %r" % [n.blocks[x[-1]].number for x in n.shard_chains]) print("Total beacon blocks received: %d" % (len([b for b in n.blocks.values() if isinstance(b, BeaconBlock)]) - 1)) print("Total beacon blocks received and signed: %d" % (len([b for b in n.blocks.keys() if b in n.sigs and len(n.sigs[b]) >= n.blocks[b].notary_req]) - 1)) print("Total main chain blocks received: %d" % (len([b for b in n.blocks.values() if isinstance(b, MainChainBlock)]) - 1)) + print("Total shard blocks received: %r" % [len([b for b in n.blocks.values() if isinstance(b, ShardCollation) and b.shard_id == i]) - 1 for i in range(SHARDS)]) import matplotlib.pyplot as plt import networkx as nx @@ -28,34 +30,20 @@ for b in n.blocks.values(): if isinstance(b, BeaconBlock): G.add_edge(b.hash, b.main_chain_ref, color='g') G.add_edge(b.hash, b.parent_hash, color='y') - else: + elif isinstance(b, MainChainBlock): G.add_edge(b.hash, b.parent_hash, color='b') - -#G.add_edge('a','b',weight=1) -#G.add_edge('a','c',weight=1) -#G.add_edge('a','d',weight=1) -#G.add_edge('a','e',weight=1) -#G.add_edge('a','f',weight=1) -#G.add_edge('a','g',weight=1) + elif isinstance(b, ShardCollation): + G.add_edge(b.hash, b.beacon_ref, color='g') + G.add_edge(b.hash, b.parent_hash, color='r') -# pos=nx.spring_layout(G) -ypos={main_genesis.hash: 0, beacon_genesis.hash: 0} -queue = n.children[main_genesis.hash] + n.children[beacon_genesis.hash] -while len(queue): - first = queue.pop(0) - if isinstance(n.blocks[first], MainChainBlock): - if n.blocks[first].parent_hash not in ypos: - queue.append(first) - continue - ypos[first] = ypos[n.blocks[first].parent_hash] + 10 - elif isinstance(n.blocks[first], BeaconBlock): - if n.blocks[first].parent_hash not in ypos or n.blocks[first].main_chain_ref not in ypos: - queue.append(first) - continue - ypos[first] = max(ypos[n.blocks[first].parent_hash] + 1, ypos[n.blocks[first].main_chain_ref] + 1) - if first in n.children: - queue.extend(n.children[first]) -pos={b.hash: (b.ts + random.randrange(5) + (5 if isinstance(b, MainChainBlock) else 0), b.ts) for b in n.blocks.values()} +def mkoffset(b): + return random.randrange(5) + \ + (5 if isinstance(b, MainChainBlock) else + 0 if isinstance(b, BeaconBlock) else + -5 - 5 * b.shard_id if isinstance(b, ShardCollation) else + None) + +pos={b.hash: (b.ts + mkoffset(b), b.ts) for b in n.blocks.values()} edges = G.edges() colors = [G[u][v]['color'] for u,v in edges] nx.draw_networkx_nodes(G,pos,node_size=10,node_shape='o',node_color='0.75')