diff --git a/cryptarchia/cryptarchia.py b/cryptarchia/cryptarchia.py index e5a3e45..05d7d07 100644 --- a/cryptarchia/cryptarchia.py +++ b/cryptarchia/cryptarchia.py @@ -1,12 +1,12 @@ -from typing import TypeAlias, List, Dict, Generator -from hashlib import sha256, blake2b -from math import floor -from copy import deepcopy -import itertools import functools -from dataclasses import dataclass, field, replace +import itertools import logging from collections import defaultdict +from copy import deepcopy +from dataclasses import dataclass, field, replace +from hashlib import blake2b, sha256 +from math import floor +from typing import Dict, Generator, List, TypeAlias import numpy as np @@ -338,10 +338,10 @@ class Follower: ): raise InvalidLeaderProof - def apply_block_to_ledger_state(self, block: BlockHeader) -> bool: + def on_block(self, block: BlockHeader): if block.id() in self.ledger_state: logger.warning("dropping already processed block") - return False + return self.validate_header(block) @@ -349,12 +349,6 @@ class Follower: new_state.apply(block) self.ledger_state[block.id()] = new_state - return True - - def on_block(self, block: BlockHeader): - if not self.apply_block_to_ledger_state(block): - return - if block.parent == self.local_chain: # simply extending the local chain self.local_chain = block.id() @@ -372,15 +366,6 @@ class Follower: self.forks.remove(new_tip) self.local_chain = new_tip - def apply_checkpoint(self, checkpoint: LedgerState): - checkpoint_block_id = checkpoint.block.id() - self.ledger_state[checkpoint_block_id] = checkpoint - if self.local_chain != self.genesis_state.block.id(): - self.forks.append(self.local_chain) - if checkpoint_block_id in self.forks: - self.forks.remove(checkpoint_block_id) - self.local_chain = checkpoint_block_id - # Evaluate the fork choice rule and return the chain we should be following def fork_choice(self) -> Hash: return maxvalid_bg( @@ -549,15 +534,9 @@ def iter_chain_blocks( def common_prefix_depth( a: Hash, b: Hash, states: Dict[Hash, LedgerState] ) -> tuple[int, list[BlockHeader], int, list[BlockHeader]]: - return common_prefix_depth_from_chains( - iter_chain_blocks(a, states), iter_chain_blocks(b, states) - ) + a_blocks = iter_chain_blocks(a, states) + b_blocks = iter_chain_blocks(b, states) - -def common_prefix_depth_from_chains( - a_blocks: Generator[BlockHeader, None, None], - b_blocks: Generator[BlockHeader, None, None], -) -> tuple[int, list[BlockHeader], int, list[BlockHeader]]: seen = {} a_suffix: list[BlockHeader] = [] b_suffix: list[BlockHeader] = [] diff --git a/cryptarchia/sync.py b/cryptarchia/sync.py index cea6f21..fee791b 100644 --- a/cryptarchia/sync.py +++ b/cryptarchia/sync.py @@ -5,20 +5,29 @@ from cryptarchia.cryptarchia import ( BlockHeader, Follower, Hash, + LedgerState, ParentNotFound, Slot, - common_prefix_depth_from_chains, iter_chain_blocks, ) -def sync(local: Follower, peers: list[Follower]): +def sync(local: Follower, peers: list[Follower], checkpoint: LedgerState | None = None): # Syncs the local block tree with the peers, starting from the local tip. # This covers the case where the local tip is not on the latest honest chain anymore. + block_fetcher = BlockFetcher(peers) + + # If the checkpoint is provided, backfill the checkpoint chain first + # before starting the sync process from the checkpoint block. + # If the backfilling fails, it means that the checkpoint chain is invalid. + # It is recommended to restart the sync process with a different checkpoint + # or without a checkpoint. + if checkpoint: + backfill_fork(local, checkpoint.block, block_fetcher) + # Repeat the sync process until no peer has a tip ahead of the local tip, # because peers' tips may advance during the sync process. - block_fetcher = BlockFetcher(peers) rejected_blocks: set[Hash] = set() while True: # Fetch blocks from the peers in the range of slots from the local tip to the latest tip. @@ -72,32 +81,31 @@ def backfill_fork( ): # Backfills a fork, which is absent in the local block tree, by fetching blocks from the peers. # During backfilling, the fork choice rule is continuously applied. - # - # If necessary, the local honest chain is also backfilled for the fork choice rule. - # This can happen if the honest chain has been built not from the genesis (i.e. checkpoint sync). - _, tip_suffix, _, fork_suffix = common_prefix_depth_from_chains( - block_fetcher.fetch_chain_backward(local.tip_id(), local), + suffix = find_disconnected_point( + local, block_fetcher.fetch_chain_backward(fork_tip.id(), local), ) - # First, backfill the local honest chain if some blocks are missing. - # In other words, backfill the local block tree, which contains the honest chain. - for block in tip_suffix: - try: - # Just apply the block to the ledger state is enough - # instead of calling `on_block` which runs the fork choice rule. - local.apply_block_to_ledger_state(block) - except Exception as e: - raise InvalidBlockTree(e) - - # Then, add blocks in the fork suffix with applying fork choice rule. + # Add blocks in the fork suffix with applying fork choice rule. # After all, add the tip of the fork suffix to apply the fork choice rule. - for i, block in enumerate(fork_suffix): + for i, block in enumerate(suffix): try: local.on_block(block) except Exception as e: - raise InvalidBlockFromBackfillFork(e, fork_suffix[i:]) + raise InvalidBlockFromBackfillFork(e, suffix[i:]) + + +def find_disconnected_point( + local: Follower, fork: Generator[BlockHeader, None, None] +) -> list[BlockHeader]: + suffix: list[BlockHeader] = [] + for block in fork: + if block.id() in local.ledger_state: + break + suffix.append(block) + suffix.reverse() + return suffix class BlockFetcher: @@ -168,12 +176,6 @@ class BlockFetcher: id = block.parent -class InvalidBlockTree(Exception): - def __init__(self, cause: Exception): - super().__init__() - self.cause = cause - - class InvalidBlockFromBackfillFork(Exception): def __init__(self, cause: Exception, invalid_suffix: list[BlockHeader]): super().__init__() diff --git a/cryptarchia/test_sync.py b/cryptarchia/test_sync.py index 2e869d6..cea2116 100644 --- a/cryptarchia/test_sync.py +++ b/cryptarchia/test_sync.py @@ -1,7 +1,7 @@ from unittest import TestCase -from cryptarchia.cryptarchia import BlockHeader, Note, Follower -from cryptarchia.sync import InvalidBlockTree, sync +from cryptarchia.cryptarchia import BlockHeader, Follower, Note +from cryptarchia.sync import InvalidBlockFromBackfillFork, sync from cryptarchia.test_common import mk_block, mk_chain, mk_config, mk_genesis_state @@ -288,17 +288,14 @@ class TestSyncFromCheckpoint(TestCase): # || # checkpoint # - # Result: A honest chain without historical blocks - # () - () - b2 - b3 + # Result: + # b0 - b1 - b2 - b3 checkpoint = peer.ledger_state[b2.id()] local = Follower(genesis, config) - local.apply_checkpoint(checkpoint) - sync(local, [peer]) + sync(local, [peer], checkpoint) self.assertEqual(local.tip(), peer.tip()) self.assertEqual(local.forks, peer.forks) - self.assertEqual( - set(local.ledger_state.keys()), set([genesis.block.id(), b2.id(), b3.id()]) - ) + self.assertEqual(local.ledger_state.keys(), peer.ledger_state.keys()) def test_sync_forks(self): # Prepare a peer with forks: @@ -331,8 +328,7 @@ class TestSyncFromCheckpoint(TestCase): # b3 - b4 checkpoint = peer.ledger_state[b2.id()] local = Follower(genesis, config) - local.apply_checkpoint(checkpoint) - sync(local, [peer]) + sync(local, [peer], checkpoint) self.assertEqual(local.tip(), peer.tip()) self.assertEqual(local.forks, peer.forks) self.assertEqual(set(local.ledger_state.keys()), set(peer.ledger_state.keys())) @@ -373,8 +369,7 @@ class TestSyncFromCheckpoint(TestCase): # b3 - b4 checkpoint = peer1.ledger_state[b4.id()] local = Follower(genesis, config) - local.apply_checkpoint(checkpoint) - sync(local, [peer0, peer1]) + sync(local, [peer0, peer1], checkpoint) self.assertEqual(local.tip(), b5) self.assertEqual(local.forks, [b4.id()]) self.assertEqual(len(local.ledger_state.keys()), 7) @@ -419,14 +414,13 @@ class TestSyncFromCheckpoint(TestCase): # b2 checkpoint = peer.ledger_state[b4.id()] local = Follower(genesis, config) - local.apply_checkpoint(checkpoint) - sync(local, [peer]) + sync(local, [peer], checkpoint) self.assertEqual(local.tip(), peer.tip()) self.assertEqual(local.forks, peer.forks) self.assertNotIn(b6.id(), local.ledger_state) self.assertNotIn(b7.id(), local.ledger_state) - def test_reject_invalid_blocks_from_backfilling_block_tree(self): + def test_reject_invalid_blocks_from_backfilling_checkpoint_chain(self): # Prepare a peer with invalid blocks in a fork: # b0 - b1 - b3 - b4 - b5 == tip # \ @@ -463,9 +457,8 @@ class TestSyncFromCheckpoint(TestCase): # Result: `InvalidBlockTree` exception checkpoint = peer.ledger_state[b7.id()] local = Follower(genesis, config) - local.apply_checkpoint(checkpoint) - with self.assertRaises(InvalidBlockTree): - sync(local, [peer]) + with self.assertRaises(InvalidBlockFromBackfillFork): + sync(local, [peer], checkpoint) def apply_invalid_block_to_ledger_state(follower: Follower, block: BlockHeader):