Make easy tests

This commit is contained in:
danielsanchezq 2023-03-30 20:26:27 +02:00
parent 53d7efbb67
commit 5dedab1d2f
2 changed files with 90 additions and 22 deletions

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import TypeAlias, List, Set, Self, Optional from typing import TypeAlias, List, Set, Self, Optional, Dict
from abc import abstractmethod from abc import abstractmethod
Id: TypeAlias = bytes Id: TypeAlias = bytes
@ -7,6 +7,10 @@ View: TypeAlias = int
Committee: TypeAlias = Set[Id] Committee: TypeAlias = Set[Id]
def int_to_id(i: int) -> Id:
return bytes(str(i), encoding="utf8")
@dataclass @dataclass
class StandardQc: class StandardQc:
block: Id block: Id
@ -46,7 +50,7 @@ class Block:
return self.qc.block return self.qc.block
def id(self) -> Id: def id(self) -> Id:
return int.to_bytes(hash(self), length=32, byteorder="little") return int_to_id(hash((self.view, self.qc.view, self.qc.block)))
@dataclass @dataclass
@ -123,7 +127,7 @@ class Overlay:
pass pass
@abstractmethod @abstractmethod
def parent_committee(self, _id: Id) -> Option[Committee]: def parent_committee(self, _id: Id) -> Optional[Committee]:
""" """
:param _id: :param _id:
:return: Some(parent committee) of the participant with Id _id withing the committee tree overlay :return: Some(parent committee) of the participant with Id _id withing the committee tree overlay
@ -131,42 +135,45 @@ class Overlay:
""" """
pass pass
@abstractmethod
def super_majority_threshold(self, _id: Id) -> int:
"""
Amount of distinct number of messages for a node with Id _id member of a committee
The return value may change depending on which committee the node is member of, including the leader
:return:
"""
if self.is_leader(_id):
pass
elif self.member_of_root_committee(_id):
pass
else:
pass
def download(view) -> Block: def download(view) -> Block:
raise NotImplementedError raise NotImplementedError
def supermajority(votes: Set[Vote]) -> bool:
raise NotImplementedError
def leader_supermajorty(votes: Set[Vote]) -> bool:
raise NotImplementedError
def more_than_supermajority(votes: Set[Vote]) -> bool:
raise NotImplementedError
class Carnot: class Carnot:
def __init__(self, _id: Id): def __init__(self, _id: Id):
self.id: Id = _id self.id: Id = _id
self.current_view: View = 0 self.current_view: View = 0
self.local_high_qc: Optional[Qc] = None self.local_high_qc: Optional[Qc] = None
self.latest_committed_view: View = 0 self.latest_committed_view: View = 0
self.safe_blocks: Set[Id] = set() self.safe_blocks: Dict[Id, Block] = dict()
self.last_timeout_view_qc: Optional[TimeoutQc] = None self.last_timeout_view_qc: Optional[TimeoutQc] = None
self.last_timeout_view: Optional[View] = None self.last_timeout_view: Optional[View] = None
self.overlay: Overlay = Overlay() # TODO: integrate overlay self.overlay: Overlay = Overlay() # TODO: integrate overlay
self.committed_blocks: Dict[Id, Block] = dict()
def block_is_safe(self, block: Block) -> bool: def block_is_safe(self, block: Block) -> bool:
match block.qc: match block.qc:
case StandardQc() as standard: case StandardQc() as standard:
if standard.view <=self.latest_committed_view: if standard.view < self.latest_committed_view:
return False return False
return block.view >= self.latest_committed_view and block.view == (standard.view + 1) return block.view >= self.latest_committed_view and block.view == (standard.view + 1)
case AggregateQc() as aggregated: case AggregateQc() as aggregated:
if aggregated.high_qc().view <= self.latest_committed_view: if aggregated.high_qc().view < self.latest_committed_view:
return False return False
return block.view >= self.current_view return block.view >= self.current_view
@ -183,15 +190,21 @@ class Carnot:
def receive_block(self, block: Block): def receive_block(self, block: Block):
assert block.parent() in self.safe_blocks assert block.parent() in self.safe_blocks
assert block.id() in self.safe_blocks or block.view <= self.latest_committed_view
if block.qc.view < self.current_view:
return
if block.id() in self.safe_blocks or block.view <= self.latest_committed_view:
return
if self.block_is_safe(block): if self.block_is_safe(block):
self.safe_blocks.add(block.id()) self.safe_blocks[block.id()] = block
self.update_high_qc(block.qc) self.update_high_qc(block.qc)
self.try_commit_grand_parent(block)
self.increment_view_qc(block.qc)
def vote(self, block: Block, votes: Set[Vote]): def vote(self, block: Block, votes: Set[Vote]):
assert block.id() in self.safe_blocks assert block.id() in self.safe_blocks
assert supermajority(votes) assert len(votes) == self.overlay.super_majority_threshold(self.id)
assert all(self.overlay.child_committee(self.id, vote.voter) for vote in votes) assert all(self.overlay.child_committee(self.id, vote.voter) for vote in votes)
assert all(vote.block == block.id() for vote in votes) assert all(vote.block == block.id() for vote in votes)
@ -224,7 +237,7 @@ class Carnot:
def propose_block(self, view: View, quorum: Quorum): def propose_block(self, view: View, quorum: Quorum):
assert self.overlay.is_leader(self.id) assert self.overlay.is_leader(self.id)
assert leader_supermajorty(quorum) assert len(quorum) == self.overlay.super_majority_threshold(self.id)
qc = self.build_qc(quorum) qc = self.build_qc(quorum)
block = Block(view=view, qc=qc) block = Block(view=view, qc=qc)
@ -245,6 +258,28 @@ class Carnot:
def broadcast(self, block): def broadcast(self, block):
pass pass
def try_commit_grand_parent(self, block: Block):
parent = self.safe_blocks.get(block.parent())
grand_parent = self.safe_blocks.get(parent.parent())
# this case should just trigger on genesis_case,
# as the preconditions on outer calls should check on block validity
if not parent or not grand_parent:
return
can_commit = (
parent.view == (grand_parent.view + 1) and
isinstance(block.qc, (StandardQc, )) and
isinstance(parent.qc, (StandardQc, ))
)
if can_commit:
self.committed_blocks[block.id()] = block
def increment_view_qc(self, qc: Qc) -> bool:
if qc.view < self.current_view:
return False
self.last_timeout_view_qc = None
self.current_view = qc.view + 1
return True
if __name__ == "__main__": if __name__ == "__main__":
pass pass

33
carnot/test_happy_path.py Normal file
View File

@ -0,0 +1,33 @@
from .carnot import *
from unittest import TestCase
class TestCarnotHappyPath(TestCase):
@staticmethod
def add_genesis_block(carnot: Carnot) -> Block:
genesis_block = Block(view=0, qc=StandardQc(block=b"", view=0))
carnot.safe_blocks[genesis_block.id()] = genesis_block
carnot.committed_blocks[genesis_block.id()] = genesis_block
return genesis_block
def test_receive_block(self):
carnot = Carnot(int_to_id(0))
genesis_block = self.add_genesis_block(carnot)
block = Block(view=1, qc=StandardQc(block=genesis_block.id(), view=0))
carnot.receive_block(block)
def test_receive_block_has_old_qc(self):
carnot = Carnot(int_to_id(0))
genesis_block = self.add_genesis_block(carnot)
# 1
block1 = Block(view=1, qc=StandardQc(block=genesis_block.id(), view=0))
carnot.receive_block(block1)
# 2
block2 = Block(view=2, qc=StandardQc(block=block1.id(), view=1))
carnot.receive_block(block2)
# 3
block3 = Block(view=3, qc=StandardQc(block=block2.id(), view=2))
carnot.receive_block(block3)