diff --git a/cryptarchia/sync.py b/cryptarchia/sync.py index b7df20a..44e9cb3 100644 --- a/cryptarchia/sync.py +++ b/cryptarchia/sync.py @@ -12,40 +12,36 @@ from cryptarchia.cryptarchia import ( ) -def sync(local: Follower, peers: list[Follower]) -> bool: +def sync(local: Follower, peers: list[Follower]): # Syncs the local block tree with the peers, starting from the local tip. # This covers the case where the local tip is not on the latest honest chain anymore. - # - # The caller should call this function repeatedly until it returns True, - # which means that no peers have blocks ahead of the local tip. - # Fetch blocks from the peers in the range of slots from the local tip to the latest tip. - # Gather orphaned blocks, which are blocks from forks that are absent in the local block tree. - start_slot = local.tip().slot - orphans: set[BlockHeader] = set() - # Filter and group peers by their tip to minimize the number of fetches. - groups = filter_and_group_peers_by_tip(peers, start_slot) - if len(groups) == 0: - return True # No peers have blocks ahead of the local tip. - for group in groups.values(): - for block in fetch_blocks_by_slot(group, start_slot): - try: - local.on_block(block) - orphans.discard(block) - except ParentNotFound: - orphans.add(block) + # Repeat the sync process until no peer has a tip ahead of the local tip. + while True: + # Fetch blocks from the peers in the range of slots from the local tip to the latest tip. + # Gather orphaned blocks, which are blocks from forks that are absent in the local block tree. + start_slot = local.tip().slot + orphans: set[BlockHeader] = set() + # Filter and group peers by their tip to minimize the number of fetches. + groups = filter_and_group_peers_by_tip(peers, start_slot) + if len(groups) == 0: # No peer has a tip ahead of the local tip. + return - # Backfill the orphan forks starting from the orphan blocks with applying fork choice rule. - # - # Sort the orphan blocks by slot in descending order to minimize the number of backfillings. - for orphan in sorted(orphans, key=lambda b: b.slot, reverse=True): - # Skip the orphan block processed during the previous backfillings. - if orphan not in local.ledger_state: - backfill_fork(local, peers, orphan) + for group in groups.values(): + for block in fetch_blocks_by_slot(group, start_slot): + try: + local.on_block(block) + orphans.discard(block) + except ParentNotFound: + orphans.add(block) - # The caller should call this function again, - # assuming that peers' tips have been updated during the sync. - return False + # Backfill the orphan forks starting from the orphan blocks with applying fork choice rule. + # + # Sort the orphan blocks by slot in descending order to minimize the number of backfillings. + for orphan in sorted(orphans, key=lambda b: b.slot, reverse=True): + # Skip the orphan block processed during the previous backfillings. + if orphan not in local.ledger_state: + backfill_fork(local, peers, orphan) def filter_and_group_peers_by_tip( diff --git a/cryptarchia/test_sync.py b/cryptarchia/test_sync.py index 1c09635..b0ce4ad 100644 --- a/cryptarchia/test_sync.py +++ b/cryptarchia/test_sync.py @@ -22,10 +22,9 @@ class TestSync(TestCase): self.assertEqual(peer.forks, []) local = Follower(genesis, config) - self.assertFalse(sync(local, [peer])) + sync(local, [peer]) self.assertEqual(local.tip(), peer.tip()) self.assertEqual(local.forks, peer.forks) - self.assertTrue(sync(local, [peer])) def test_sync_single_chain_from_middle(self): # b0 - b1 - b2 - b3 @@ -47,10 +46,9 @@ class TestSync(TestCase): for b in [b0, b1]: peer.on_block(b) # start syncing from b1 - self.assertFalse(sync(local, [peer])) + sync(local, [peer]) self.assertEqual(local.tip(), peer.tip()) self.assertEqual(local.forks, peer.forks) - self.assertTrue(sync(local, [peer])) def test_sync_forks_from_genesis(self): # b0 - b1 - b2 - b5 == tip @@ -72,10 +70,9 @@ class TestSync(TestCase): self.assertEqual(peer.forks, [b4.id()]) local = Follower(genesis, config) - self.assertFalse(sync(local, [peer])) + sync(local, [peer]) self.assertEqual(local.tip(), peer.tip()) self.assertEqual(local.forks, peer.forks) - self.assertTrue(sync(local, [peer])) def test_sync_forks_from_middle(self): # b0 - b1 - b2 - b5 == tip @@ -102,10 +99,9 @@ class TestSync(TestCase): local = Follower(genesis, config) for b in [b0, b1, b3]: peer.on_block(b) - self.assertFalse(sync(local, [peer])) + sync(local, [peer]) self.assertEqual(local.tip(), peer.tip()) self.assertEqual(local.forks, peer.forks) - self.assertTrue(sync(local, [peer])) def test_sync_forks_by_backfilling(self): # b0 - b1 - b2 - b5 == tip @@ -131,11 +127,10 @@ class TestSync(TestCase): local = Follower(genesis, config) for b in [b0, b1]: peer.on_block(b) - self.assertFalse(sync(local, [peer])) + sync(local, [peer]) self.assertEqual(local.tip(), peer.tip()) self.assertEqual(local.forks, peer.forks) self.assertEqual(len(local.ledger_state), len(peer.ledger_state)) - self.assertTrue(sync(local, [peer])) def test_sync_multiple_peers_from_genesis(self): # Peer-0: b5 @@ -169,11 +164,10 @@ class TestSync(TestCase): self.assertEqual(peer2.forks, []) local = Follower(genesis, config) - self.assertFalse(sync(local, [peer0, peer1, peer2])) + sync(local, [peer0, peer1, peer2]) self.assertEqual(local.tip(), b5) self.assertEqual(local.forks, [b4.id()]) self.assertEqual(len(local.ledger_state), 7) - self.assertTrue(sync(local, [peer0, peer1, peer2])) class TestSyncFromCheckpoint(TestCase): @@ -201,7 +195,7 @@ class TestSyncFromCheckpoint(TestCase): checkpoint = peer.ledger_state[b2.id()] local = Follower(genesis, config) local.apply_checkpoint(checkpoint) - self.assertFalse(sync(local, [peer])) + sync(local, [peer]) # Result: # () - () - b2 - b3 # || @@ -211,7 +205,6 @@ class TestSyncFromCheckpoint(TestCase): self.assertEqual( set(local.ledger_state.keys()), set([genesis.block.id(), b2.id(), b3.id()]) ) - self.assertTrue(sync(local, [peer])) def test_sync_forks(self): # checkpoint @@ -241,7 +234,7 @@ class TestSyncFromCheckpoint(TestCase): checkpoint = peer.ledger_state[b2.id()] local = Follower(genesis, config) local.apply_checkpoint(checkpoint) - self.assertFalse(sync(local, [peer])) + sync(local, [peer]) # Result: # b0 - b1 - b2 - b5 == tip # \ @@ -249,7 +242,6 @@ class TestSyncFromCheckpoint(TestCase): self.assertEqual(local.tip(), peer.tip()) self.assertEqual(local.forks, peer.forks) self.assertEqual(set(local.ledger_state.keys()), set(peer.ledger_state.keys())) - self.assertTrue(sync(local, [peer])) def test_sync_from_dishonest_checkpoint(self): # Peer0: b0 - b1 - b2 - b5 == tip @@ -284,11 +276,10 @@ class TestSyncFromCheckpoint(TestCase): checkpoint = peer1.ledger_state[b4.id()] local = Follower(genesis, config) local.apply_checkpoint(checkpoint) - self.assertFalse(sync(local, [peer0, peer1])) + sync(local, [peer0, peer1]) # b0 - b1 - b2 - b5 == tip # \ # b3 - b4 self.assertEqual(local.tip(), b5) self.assertEqual(local.forks, [b4.id()]) self.assertEqual(len(local.ledger_state.keys()), 7) - self.assertTrue(sync(local, [peer0, peer1]))