diff --git a/cryptarchia/cryptarchia.py b/cryptarchia/cryptarchia.py index 05d7d07..a46ce89 100644 --- a/cryptarchia/cryptarchia.py +++ b/cryptarchia/cryptarchia.py @@ -1,5 +1,5 @@ import functools -import itertools +from itertools import islice import logging from collections import defaultdict from copy import deepcopy @@ -7,6 +7,7 @@ from dataclasses import dataclass, field, replace from hashlib import blake2b, sha256 from math import floor from typing import Dict, Generator, List, TypeAlias +from enum import Enum import numpy as np @@ -308,6 +309,9 @@ class EpochState: def nonce(self) -> bytes: return self.nonce_snapshot.nonce +class State(Enum): + ONLINE = 1 + BOOTSTRAPPING = 2 class Follower: def __init__(self, genesis_state: LedgerState, config: Config): @@ -317,12 +321,28 @@ class Follower: self.genesis_state = genesis_state self.ledger_state = {genesis_state.block.id(): genesis_state.copy()} self.epoch_state = {} + self.state = State.BOOTSTRAPPING + self.lib = genesis_state.block.id() # Last immutable block, initially the genesis block + + def to_online(self): + """ + Call this method when the follower has finished bootstrapping. While this is somewhat left to implementations + https://www.notion.so/Cryptarchia-v1-Bootstrapping-Synchronization-1fd261aa09df81ac94b5fb6a4eff32a6 contains a great deal + of information and is the reference for the Rust implementation. + """ + assert self.state == State.BOOTSTRAPPING, "Follower is not in BOOTSTRAPPING state" + self.state = State.ONLINE + self.update_lib() def validate_header(self, block: BlockHeader): # TODO: verify blocks are not in the 'future' if block.parent not in self.ledger_state: raise ParentNotFound + if height(block.parent, self.ledger_state) < height(self.lib, self.ledger_state): + # If the block is not a descendant of the last immutable block, we cannot process it. + raise ImmutableFork + current_state = self.ledger_state[block.parent].copy() epoch_state = self.compute_epoch_state( @@ -366,15 +386,51 @@ class Follower: self.forks.remove(new_tip) self.local_chain = new_tip + if self.state == State.ONLINE: + self.update_lib() + + + # Update the lib, and prune forks that do not descend from it. + def update_lib(self): + """ + Computes the last immutable block, which is the k-th block in the chain. + The last immutable block is the block that is guaranteed to be part of the chain + and will not be reverted. + """ + if self.state != State.ONLINE: + return + # prune forks that do not descend from the last immutable block, this is needed to avoid Genesis rule to roll back + # past the LIB + self.lib = next(islice(iter_chain(self.local_chain, self.ledger_state), self.config.k, None), self.local_chain).block.id() + self.forks = [ + f for f in self.forks if is_ancestor(self.lib, f, self.ledger_state) + ] + self.ledger_state = { + k: v + for k, v in self.ledger_state.items() + if is_ancestor(self.lib, k, self.ledger_state) or is_ancestor(k, self.lib, self.ledger_state) + } + + # Evaluate the fork choice rule and return the chain we should be following def fork_choice(self) -> Hash: - return maxvalid_bg( - self.local_chain, - self.forks, - k=self.config.k, - s=self.config.s, - states=self.ledger_state, - ) + if self.state == State.BOOTSTRAPPING: + return maxvalid_bg( + self.local_chain, + self.forks, + k=self.config.k, + s=self.config.s, + states=self.ledger_state, + ) + elif self.state == State.ONLINE: + return maxvalid_mc( + self.local_chain, + self.forks, + k=self.config.k, + states=self.ledger_state, + ) + else: + raise RuntimeError(f"Unknown follower state: {self.state}") def tip(self) -> BlockHeader: return self.tip_state().block @@ -515,6 +571,20 @@ class Leader: return ticket < Hash.ORDER * phi(self.config.active_slot_coeff, relative_stake) +def height(block: Hash, states: Dict[Hash, LedgerState]) -> int: + """ + Returns the height of the block in the chain, i.e. the number of blocks + between this block and the genesis block. + """ + if block not in states: + raise ValueError("State not found in states") + + height = 0 + while block in states: + height += 1 + block = states[block].block.parent + + return height def iter_chain( tip: Hash, states: Dict[Hash, LedgerState] @@ -530,6 +600,14 @@ def iter_chain_blocks( for state in iter_chain(tip, states): yield state.block +def is_ancestor(a: Hash, b: Hash, states: Dict[Hash, LedgerState]) -> bool: + """ + Returns True if `a` is an ancestor of `b` in the chain. + """ + for state in iter_chain(b, states): + if state.block.id() == a: + return True + return False def common_prefix_depth( a: Hash, b: Hash, states: Dict[Hash, LedgerState] @@ -592,7 +670,7 @@ def block_children(states: Dict[Hash, LedgerState]) -> Dict[Hash, set[Hash]]: return children -# Implementation of the Cryptarchia fork choice rule (following Ouroborous Genesis). +# Implementation of the Ouroboros Genesis fork choice rule. # The fork choice has two phases: # 1. if the chain is not forking too deeply, we apply the longest chain fork choice rule # 2. otherwise we look at the chain density immidiately following the fork @@ -633,6 +711,33 @@ def maxvalid_bg( return cmax +# Implementation of the Ouroboros Praos fork choice rule. +# The fork choice has two phases: +# 1. if the chain is not forking too deeply, we apply the longest chain fork choice rule +# 2. otherwise we discard the fork +# +# k defines the forking depth of a chain at which point we switch phases. +def maxvalid_mc( + local_chain: Hash, + forks: List[Hash], + k: int, + states: Dict[Hash, LedgerState], +) -> Hash: + assert type(local_chain) == Hash, type(local_chain) + assert all(type(f) == Hash for f in forks) + + cmax = local_chain + for fork in forks: + cmax_depth, _, fork_depth, _ = common_prefix_depth( + cmax, fork, states + ) + if cmax_depth <= k: + # Longest chain fork choice rule + if cmax_depth < fork_depth: + cmax = fork + + return cmax + class ParentNotFound(Exception): def __str__(self): return "Parent not found" @@ -642,6 +747,10 @@ class InvalidLeaderProof(Exception): def __str__(self): return "Invalid leader proof" +class ImmutableFork(Exception): + def __str__(self): + return "Block is forking deeper than the last immutable block" + if __name__ == "__main__": pass diff --git a/cryptarchia/test_fork_choice.py b/cryptarchia/test_fork_choice.py index 45f197b..c850f78 100644 --- a/cryptarchia/test_fork_choice.py +++ b/cryptarchia/test_fork_choice.py @@ -3,11 +3,14 @@ from unittest import TestCase from copy import deepcopy from cryptarchia.cryptarchia import ( maxvalid_bg, + maxvalid_mc, Slot, Note, + State, Follower, common_prefix_depth, LedgerState, + ImmutableFork, ) from .test_common import mk_chain, mk_config, mk_genesis_state, mk_block @@ -200,6 +203,11 @@ class TestForkChoice(TestCase): == short_chain[-1].id() ) + assert ( + maxvalid_mc(short_chain[-1].id(), [long_chain[-1].id()], k,states) + == short_chain[-1].id() + ) + # However, if we set k to the fork length, it will be accepted k = len(long_chain) assert ( @@ -207,6 +215,11 @@ class TestForkChoice(TestCase): == long_chain[-1].id() ) + assert ( + maxvalid_mc(short_chain[-1].id(), [long_chain[-1].id()], k, states) + == long_chain[-1].id() + ) + def test_fork_choice_long_dense_chain(self): # The longest chain is also the densest after the fork short_note, long_note = Note(sk=0, value=100), Note(sk=1, value=100) @@ -235,6 +248,13 @@ class TestForkChoice(TestCase): == long_chain[-1].id() ) + # praos fc rule should not accept a chain that diverged more than k blocks, + # even if it is longer + assert ( + maxvalid_mc(short_chain[-1].id(), [long_chain[-1].id()], k, states) + == short_chain[-1].id() + ) + def test_fork_choice_integration(self): n_a, n_b = Note(sk=0, value=10), Note(sk=1, value=10) notes = [n_a, n_b] @@ -281,3 +301,76 @@ class TestForkChoice(TestCase): assert follower.tip_id() == b4.id() assert len(follower.forks) == 1 and follower.forks[0] == b2.id(), follower.forks + + # -- switch to online mode -- + # + # b2 (does not descend from the LIB and is thus pruned) + # / + # b1 + # \ + # b3 (LIB) - b4 == tip + # + follower.to_online() + assert follower.lib == b3.id(), follower.lib + assert len(follower.forks) == 0, follower.forks + assert b2.id() not in follower.forks + + # -- extend a fork deeper than the LIB -- + # + # - - - - - - b5 + # / + # b1 + # \ + # b3 (LIB) - b4 == tip + # + b5 = mk_block(b1, 4, n_a) + with self.assertRaises(ImmutableFork): + follower.on_block(b5) + + # -- extend the main chain shallower than k -- + # + # b1 + # \ + # b3 - b4 (pruned) + # \ + # - - b7 (LIB) - b8 == tip + b7 = mk_block(b3, 4, n_b) + b8 = mk_block(b7, 5, n_b) + + follower.on_block(b7) + assert len(follower.forks) == 1 and b7.id() in follower.forks + + follower.on_block(b8) + assert follower.tip_id() == b8.id() + # b4 was pruned as it forks deeper than the LIB + assert len(follower.forks) == 0, follower.forks + + # Even in bootstrap mode, the follower should not accept blocks that fork deeper than k + follower.state = State.BOOTSTRAPPING + with self.assertRaises(ImmutableFork): + follower.on_block(b5) + + # But it should switch a chain diverging more than k as long as it + # descends from the LIB + # + # b1 + # \ + # b3 - - - - - - - b10 - b11 - b12 + # \ | + # - - b7 (LIB) - b8 - b9 == tip + b8 = mk_block(b7, 5, n_b) + b9 = mk_block(b8, 6, n_b) + b10 = mk_block(b7, 7, n_a) + b11 = mk_block(b10, 8, n_a) + b12 = mk_block(b11, 9, n_a) + follower.on_block(b8) + follower.on_block(b9) + + assert follower.tip_id() == b9.id() + + follower.on_block(b10) + follower.on_block(b11) + follower.on_block(b12) + + assert follower.tip_id() == b12.id() + assert follower.lib == b7.id(), follower.lib \ No newline at end of file