diff --git a/carnot/carnot_vote_aggregation.py b/carnot/carnot_vote_aggregation.py index 1cddacd..1d1c503 100644 --- a/carnot/carnot_vote_aggregation.py +++ b/carnot/carnot_vote_aggregation.py @@ -272,28 +272,32 @@ class Carnot2(Carnot): return concatenated_qc - # Similarly aggregated qcs are concatenated after timeout t2: - def concatenate_aggregate_qcs(qc_set: Set[AggregateQc]) -> AggregateQc: - # Initialize the attributes for the concatenated AggregateQc + # Similarly aggregated qcs are concatenated after timeout t2. + from typing import Set, List, Optional, Union + + # Define your types here (Id, View, StandardQc, AggregateQc, etc.) + + def concatenate_aggregate_qcs(qc_set: Set[Union[StandardQc, AggregateQc]]) -> AggregateQc: + if qc_set is None: + return None + concatenated_qcs = [] concatenated_view = None concatenated_sender_ids = set() highest_standard_qc = None - # Iterate through the input set of AggregateQc objects for qc in qc_set: - concatenated_qcs.extend(qc.qcs) - concatenated_sender_ids.update(qc.sender_ids) + if isinstance(qc, AggregateQc): + concatenated_qcs.extend(qc.qcs) + concatenated_sender_ids.update(qc.sender_ids) - # Choose the view value from the first AggregateQc in the set - if concatenated_view is None: - concatenated_view = qc.view + if concatenated_view is None: + concatenated_view = qc.view - # Find the highest StandardQc among the AggregateQc.high_qc.view fields - if highest_standard_qc is None or qc.high_qc.view > highest_standard_qc.view: - highest_standard_qc = qc.high_qc + if highest_standard_qc is None or (isinstance(qc.highest_qc, StandardQc) and + qc.highest_qc.view > highest_standard_qc.view): + highest_standard_qc = qc.highest_qc - # Create the concatenated AggregateQc object concatenated_aggregate_qc = AggregateQc( qcs=concatenated_qcs, highest_qc=highest_standard_qc, @@ -369,3 +373,5 @@ class Carnot2(Carnot): sender=self.id ) return Send(payload=timeout_msg, to=self.overlay.my_committee()) + + diff --git a/carnot/my_carnot.py b/carnot/my_carnot.py index ee72ddd..cfd9907 100644 --- a/carnot/my_carnot.py +++ b/carnot/my_carnot.py @@ -74,10 +74,11 @@ class StandardQc: @dataclass class AggregateQc: + sender_ids: Set[Id] qcs: List[View] highest_qc: StandardQc view: View - sender_ids: Set[Id] + def view(self) -> View: return self.view @@ -86,6 +87,9 @@ class AggregateQc: assert self.highest_qc.get_view == max(self.qcs) return self.highest_qc + def __hash__(self): + # Define a hash function based on the attributes that need to be considered for hashing + return hash((frozenset(self.sender_ids), tuple(self.qcs), self.highest_qc, self.view)) Qc: TypeAlias = StandardQc | AggregateQc diff --git a/carnot/test_carnot_vote_aggregation.py b/carnot/test_carnot_vote_aggregation.py index 22ad062..0cc2693 100644 --- a/carnot/test_carnot_vote_aggregation.py +++ b/carnot/test_carnot_vote_aggregation.py @@ -1,7 +1,7 @@ import unittest import carnot -from carnot import merging_committees +from carnot.carnot_vote_aggregation import AggregateQc from carnot.merging_committees import merge_committees import itertools import unittest @@ -49,5 +49,49 @@ class TestConcatenateStandardQcs(unittest.TestCase): self.assertEqual(concatenated_qc, expected_qc) +class TestConcatenateAggregateQcs(unittest.TestCase): + + def test_concatenate_aggregate_qcs_single_qc(self): + # Test concatenating a single AggregateQc + qc1 = AggregateQc( + sender_ids={1, 2, 3}, + qcs=[1, 2, 3], + highest_qc=3, + view=1 + ) + aggregate_qcs = {qc1} + concatenated_qc = carnot.carnot_vote_aggregation.Carnot2.concatenate_aggregate_qcs(aggregate_qcs) + self.assertEqual(concatenated_qc, qc1) + + def test_concatenate_aggregate_qcs_multiple_qcs(self): + # Test concatenating multiple AggregateQcs + qc1 = AggregateQc( + sender_ids={1, 2, 3}, + qcs=[1, 2, 3], + highest_qc=3, + view=1 + ) + qc2 = AggregateQc( + sender_ids={4, 5, 6}, + qcs=[4, 5, 6], + highest_qc=6, + view=2 + ) + qc3 = AggregateQc( + sender_ids={7, 8, 9}, + qcs=[7, 8, 9], + highest_qc=9, + view=3 + ) + aggregate_qcs = {qc1, qc2, qc3} + concatenated_qc = carnot.carnot_vote_aggregation.Carnot2.concatenate_aggregate_qcs(aggregate_qcs) + + # Assert that the concatenated AggregateQc has the correct attributes + self.assertEqual(concatenated_qc.sender_ids, {1, 2, 3, 4, 5, 6, 7, 8, 9}) + self.assertEqual(concatenated_qc.qcs, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + self.assertEqual(concatenated_qc.highest_qc, 9) + self.assertEqual(concatenated_qc.view, 1) # View should be from the first QC + + if __name__ == '__main__': unittest.main()