diff --git a/cryptarchia/cryptarchia.py b/cryptarchia/cryptarchia.py index 443220b..a88af3c 100644 --- a/cryptarchia/cryptarchia.py +++ b/cryptarchia/cryptarchia.py @@ -536,8 +536,9 @@ class Follower: new_chain.blocks.append(block) # We may need to switch forks, lets run the fork choice rule to check. - new_chain = self.fork_choice() - if new_chain != self.local_chain: + new_chain_head = self.fork_choice() + if new_chain_head != self.local_chain.tip_id(): + assert new_chain_head == new_chain.tip_id() self.forks.remove(new_chain) self.forks.append(self.local_chain) self.local_chain = new_chain @@ -567,8 +568,8 @@ class Follower: # Evaluate the fork choice rule and return the chain we should be following def fork_choice(self) -> Chain: return maxvalid_bg( - self.local_chain, - self.forks, + self.local_chain.tip_id(), + [f.tip_id() for f in self.forks], self.ledger_state, k=self.config.k, s=self.config.s, @@ -766,16 +767,6 @@ def common_prefix_depth( assert False -def chain_density_old(chain: Chain, prefix_len: int, slot: Slot) -> int: - return len( - [ - block - for h, block in enumerate(chain.blocks) - if h >= prefix_len and block.slot.absolute_slot < slot.absolute_slot - ] - ) - - def chain_density( head: Id, slot: Slot, reorg_depth: int, states: Dict[Id, LedgerState] ) -> int: @@ -800,14 +791,12 @@ def maxvalid_bg( k: int, s: int, ) -> Chain: - # assert type(local_chain) == Id - # assert all(type(f) == Id for f in forks) + assert type(local_chain) == Id + assert all(type(f) == Id for f in forks) cmax = local_chain for fork in forks: - local_depth, fork_depth = common_prefix_depth( - cmax.tip_id(), fork.tip_id(), states - ) + local_depth, fork_depth = common_prefix_depth(cmax, fork, states) if local_depth <= k: # Classic longest chain rule with parameter k if local_depth < fork_depth: @@ -815,31 +804,18 @@ def maxvalid_bg( else: # The chain is forking too much, we need to pay a bit more attention # In particular, select the chain that is the densest after the fork - forking_block = local_chain.tip_id() + forking_block = local_chain for _ in range(local_depth): forking_block = states[forking_block].block.parent forking_slot = Slot(states[forking_block].block.slot.absolute_slot + s) - cmax_density = chain_density( - cmax.tip_id(), forking_slot, local_depth, states - ) - candidate_density = chain_density( - fork.tip_id(), forking_slot, fork_depth, states - ) - - prefix_len = cmax.length() - local_depth - - assert cmax_density == ( - d := chain_density_old(cmax, prefix_len, forking_slot) - ), f"{cmax_density} != {d}" - assert candidate_density == ( - d := chain_density_old(fork, prefix_len, forking_slot) - ), f"{candidate_density} != {d}" + cmax_density = chain_density(cmax, forking_slot, local_depth, states) + candidate_density = chain_density(fork, forking_slot, fork_depth, states) if cmax_density < candidate_density: cmax = fork - assert type(cmax) == Chain + assert type(cmax) == Id return cmax diff --git a/cryptarchia/test_fork_choice.py b/cryptarchia/test_fork_choice.py index ad4ad63..c8a6a79 100644 --- a/cryptarchia/test_fork_choice.py +++ b/cryptarchia/test_fork_choice.py @@ -100,11 +100,17 @@ class TestForkChoice(TestCase): b.id(): LedgerState(block=b) for b in short_chain.blocks + long_chain.blocks } - assert maxvalid_bg(short_chain, [long_chain], states, k, s) == short_chain + assert ( + maxvalid_bg(short_chain.tip_id(), [long_chain.tip_id()], states, k, s) + == short_chain.tip_id() + ) # However, if we set k to the fork length, it will be accepted k = long_chain.length() - assert maxvalid_bg(short_chain, [long_chain], states, k, s) == long_chain + assert ( + maxvalid_bg(short_chain.tip_id(), [long_chain.tip_id()], states, k, s) + == long_chain.tip_id() + ) def test_fork_choice_long_dense_chain(self): # The longest chain is also the densest after the fork @@ -133,7 +139,10 @@ class TestForkChoice(TestCase): b.id(): LedgerState(block=b) for b in short_chain.blocks + long_chain.blocks } - assert maxvalid_bg(short_chain, [long_chain], states, k, s) == long_chain + assert ( + maxvalid_bg(short_chain.tip_id(), [long_chain.tip_id()], states, k, s) + == long_chain.tip_id() + ) def test_fork_choice_integration(self): c_a, c_b = Coin(sk=0, value=10), Coin(sk=1, value=10)