Updated the RPJ algorithm

This commit is contained in:
Vitalik Buterin 2018-08-03 17:21:38 -04:00
parent 13a30dedf4
commit 82620d82de
2 changed files with 70 additions and 24 deletions

View File

@ -19,9 +19,9 @@ def hash_to_int(h):
o = (o << 8) + c o = (o << 8) + c
return o return o
NOTARIES = 20 NOTARIES = 60
SLOT_SIZE = 6 SLOT_SIZE = 6
EPOCH_LENGTH = 5 EPOCH_LENGTH = 20
# Not a full RANDAO; stub for now # Not a full RANDAO; stub for now
class Block(): class Block():
@ -38,11 +38,12 @@ class Block():
return SLOT_SIZE * self.slot return SLOT_SIZE * self.slot
class Sig(): class Sig():
def __init__(self, proposer, targets, ts): def __init__(self, proposer, targets, slot, ts):
self.proposer = proposer self.proposer = proposer
self.targets = targets self.targets = targets
self.hash = os.urandom(32) self.slot = slot
self.ts = ts self.ts = ts
self.hash = os.urandom(32)
genesis = Block(None, 0, 0) genesis = Block(None, 0, 0)
@ -58,6 +59,7 @@ class Node():
self.parentqueue = {} self.parentqueue = {}
self.children = {} self.children = {}
self.scores = {} self.scores = {}
self.scores_at_height = {}
self.justified = {} self.justified = {}
self.finalized = {} self.finalized = {}
self.ts = ts self.ts = ts
@ -68,8 +70,8 @@ class Node():
self.sleepy = sleepy self.sleepy = sleepy
self.careless = careless self.careless = careless
self.first_round = True self.first_round = True
self.last_made_block = -1 self.last_made_block = 0
self.last_made_sig = -1 self.last_made_sig = 0
def broadcast(self, x): def broadcast(self, x):
if self.sleepy and self.ts: if self.sleepy and self.ts:
@ -147,14 +149,6 @@ class Node():
b = self.blocks[b.parent_hash] b = self.blocks[b.parent_hash]
return a return a
def get_ancestor_at_slot(self, a, slot, strict=True):
while a.slot > slot and a.hash != genesis.hash:
a = self.blocks[a.parent_hash]
if a.slot == slot or strict is False:
return a
else:
return None
def is_descendant(self, a, b): def is_descendant(self, a, b):
a, b = self.blocks[a], self.blocks[b] a, b = self.blocks[a], self.blocks[b]
while b.height > a.height: while b.height > a.height:
@ -198,22 +192,70 @@ class Node():
max_score = max([0] + [self.scores.get(self.main_chain[i], 0) for i in range(anc.height + 1, len(self.main_chain))]) max_score = max([0] + [self.scores.get(self.main_chain[i], 0) for i in range(anc.height + 1, len(self.main_chain))])
# Process scoring # Process scoring
max_newchain_score = 0 max_newchain_score = 0
for c in sig.targets: for i, c in list(enumerate(sig.targets))[::-1]:
self.scores[c] = self.scores.get(c, 0) + 1 slot = sig.slot - 1 - i
if self.scores[c] == NOTARIES * 2 // 3: slot_key = slot.to_bytes(4, 'big')
assert self.blocks[c].slot <= slot
# If a parent and child block have non-consecutive slots, then the parent
# block is also considered to be the canonical block at all of the intermediate
# slot numbers. We store the scores for the block at each height separately
self.scores_at_height[slot_key + c] = self.scores_at_height.get(slot_key + c, 0) + 1
# For fork choice rule purposes, the score of a block is the highest score
# that it has at any height
self.scores[c] = max(self.scores.get(c, 0), self.scores_at_height[slot_key + c])
# If 2/3 of notaries vote for a block, it is justified
if self.scores_at_height[slot_key + c] == NOTARIES * 2 // 3:
self.justified[c] = True self.justified[c] = True
c_minus_one_epoch = self.get_ancestor_at_slot(self.blocks[c], self.blocks[c].slot - EPOCH_LENGTH) c2 = c
if c_minus_one_epoch and c_minus_one_epoch.hash in self.justified: self.log("Justified: %d %s" % (slot, hexlify(c).decode('utf-8')[:8]))
self.finalized[c_minus_one_epoch.hash] = True
# If EPOCH_LENGTH+1 blocks are justified in a row, the oldest is
# considered finalized
finalize = True
for slot2 in range(slot - 1, max(slot - EPOCH_LENGTH, 0) - 1, -1):
if slot2 < self.blocks[c2].slot:
c2 = self.blocks[c2].parent_hash
if self.scores_at_height.get(slot2.to_bytes(4, 'big') + c2, 0) < (NOTARIES * 2 // 3):
finalize = False
# self.log("Not quite finalized: stopped at %d needed %d" % (slot2, max(slot - EPOCH_LENGTH, 0)))
break
if finalize and c2 not in self.finalized:
self.log("Finalized: %d %s" % (self.blocks[c2].slot, hexlify(c).decode('utf-8')[:8]))
self.finalized[c2] = True
# Find the maximum score of a block on the chain that this sig is weighing on
if self.blocks[c].slot > anc.slot: if self.blocks[c].slot > anc.slot:
max_newchain_score = max(max_newchain_score, self.scores[c]) max_newchain_score = max(max_newchain_score, self.scores[c])
# If it's higher, switch over the canonical chain
if max_newchain_score > max_score: if max_newchain_score > max_score:
self.main_chain = self.main_chain[:anc.height+1] self.main_chain = self.main_chain[:anc.height+1]
self.recalculate_head() self.recalculate_head()
self.sigs[sig.hash] = sig self.sigs[sig.hash] = sig
# Rebroadcast # Rebroadcast
self.network.broadcast(self, sig) self.network.broadcast(self, sig)
# Get the portion of the main chain that is within the last EPOCH_LENGTH
# slots, once again duplicating the parent in cases where the parent and
# child's slots are not consecutive
def get_sig_targets(self, start_slot):
o = []
i = len(self.main_chain) - 1
for slot in range(start_slot - 1, max(start_slot - EPOCH_LENGTH, 0) - 1, -1):
if slot < self.blocks[self.main_chain[i]].slot:
i -= 1
o.append(self.main_chain[i])
for i, x in enumerate(o):
assert self.blocks[x].slot <= start_slot - 1 - i
assert len(o) == min(EPOCH_LENGTH, start_slot)
return o
def tick(self): def tick(self):
self.ts += 0.1 self.ts += 0.1
self.log("Tick: %.1f" % self.ts, lvl=1) self.log("Tick: %.1f" % self.ts, lvl=1)
@ -227,7 +269,9 @@ class Node():
sig_from = len(self.main_chain) - 1 sig_from = len(self.main_chain) - 1
while sig_from > 0 and self.blocks[self.main_chain[sig_from]].slot >= slot - EPOCH_LENGTH: while sig_from > 0 and self.blocks[self.main_chain[sig_from]].slot >= slot - EPOCH_LENGTH:
sig_from -= 1 sig_from -= 1
self.broadcast(Sig(self.id, self.main_chain[sig_from:][::-1], self.ts)) sig = Sig(self.id, self.get_sig_targets(slot), slot, self.ts)
# self.log('Sig:', self.id, sig.slot, ' '.join([hexlify(t).decode('utf-8')[:4] for t in sig.targets]))
self.broadcast(sig)
self.last_made_sig = slot self.last_made_sig = slot
# Process time queue # Process time queue
while len(self.timequeue) and self.timequeue[0].min_timestamp() <= self.ts: while len(self.timequeue) and self.timequeue[0].min_timestamp() <= self.ts:

View File

@ -2,11 +2,11 @@ from networksim import NetworkSimulator
from ghost_node import Node, NOTARIES, Block, Sig, genesis, SLOT_SIZE from ghost_node import Node, NOTARIES, Block, Sig, genesis, SLOT_SIZE
from distributions import normal_distribution from distributions import normal_distribution
net = NetworkSimulator(latency=300) net = NetworkSimulator(latency=150)
notaries = [Node(i, net, ts=max(normal_distribution(60, 60)(), 0) * 0.1, sleepy=i%4==0) for i in range(NOTARIES)] notaries = [Node(i, net, ts=max(normal_distribution(50, 50)(), 0) * 0.1, sleepy=False) for i in range(NOTARIES)]
net.agents = notaries net.agents = notaries
net.generate_peers() net.generate_peers()
for i in range(10000): for i in range(15000):
net.tick() net.tick()
for n in notaries: for n in notaries:
print("Local timestamp: %.1f, timequeue len %d" % (n.ts, len(n.timequeue))) print("Local timestamp: %.1f, timequeue len %d" % (n.ts, len(n.timequeue)))
@ -65,6 +65,8 @@ otheredges = [(u,v) for (u,v) in edges if G[u][v]['color'] == '0.75']
nx.draw_networkx_edges(G, pos, edgelist=otheredges, width=1, edge_color='0.75') nx.draw_networkx_edges(G, pos, edgelist=otheredges, width=1, edge_color='0.75')
nx.draw_networkx_edges(G, pos, edgelist=blockedges, width=2, edge_color='b') nx.draw_networkx_edges(G, pos, edgelist=blockedges, width=2, edge_color='b')
print('Scores:', [n.scores.get(c, 0) for c in n.main_chain])
plt.axis('off') plt.axis('off')
# plt.savefig("degree.png", bbox_inches="tight") # plt.savefig("degree.png", bbox_inches="tight")
plt.show() plt.show()