diff --git a/carnot/tree_overlay.py b/carnot/tree_overlay.py index ef180fd..361a309 100644 --- a/carnot/tree_overlay.py +++ b/carnot/tree_overlay.py @@ -8,9 +8,10 @@ import random class CarnotTree: def __init__(self, nodes: List[Id], number_of_committees: int): self.number_of_committees = number_of_committees - self.committee_size = len(nodes) // number_of_committees - self.inner_committees, self.membership_committees = CarnotTree.build_committee_from_nodes_with_size( - nodes, self.number_of_committees, self.committee_size + self.committee_size, self.inner_committees, self.membership_committees = ( + CarnotTree.build_committee_from_nodes_with_size( + nodes, self.number_of_committees + ) ) self.committees = {k: v for v, k in self.inner_committees.items()} self.nodes = CarnotTree.build_nodes_index(nodes, self.committee_size) @@ -24,21 +25,22 @@ class CarnotTree: def build_committee_from_nodes_with_size( nodes: List[Id], number_of_committees: int, - committee_size: int - ) -> Tuple[Dict[int, Id], Dict[int, Set[Id]]]: + ) -> Tuple[int, Dict[int, Id], Dict[int, Set[Id]]]: + committee_size, remainder = divmod(len(nodes), number_of_committees) committees = [ - # TODO: This hash method should be specific to what we would want to use for the protocol set(nodes[n*committee_size:(n+1)*committee_size]) for n in range(0, number_of_committees) ] - # TODO: for now simples solution is make latest committee bigger - remainder = len(nodes) % committee_size - remainder_nodes = set(nodes[-remainder:]) - committees[number_of_committees-1] |= remainder_nodes + # refill committees with extra nodes, + # we fill the leafs first as they are the least important + if remainder != 0: + cycling_committees = itertools.cycle(reversed(committees)) + for node in nodes[-remainder:]: + next(cycling_committees).add(node) committees = [frozenset(s) for s in committees] - + # TODO: This hash method should be specific to what we would want to use for the protocol hashes = [hash(s) for s in committees] - return dict(enumerate(hashes)), dict(enumerate(committees)) + return committee_size, dict(enumerate(hashes)), dict(enumerate(committees)) @staticmethod def build_nodes_index(nodes: List[Id], committee_size: int) -> Dict[Id, int]: