diff --git a/carnot/carnot_vote_aggregation.py b/carnot/carnot_vote_aggregation.py index 1d1c503..33c7c33 100644 --- a/carnot/carnot_vote_aggregation.py +++ b/carnot/carnot_vote_aggregation.py @@ -294,8 +294,10 @@ class Carnot2(Carnot): if concatenated_view is None: concatenated_view = qc.view - if highest_standard_qc is None or (isinstance(qc.highest_qc, StandardQc) and - qc.highest_qc.view > highest_standard_qc.view): + if highest_standard_qc is None or ( + isinstance(qc.highest_qc, StandardQc) and + qc.highest_qc.view > highest_standard_qc.view + ): highest_standard_qc = qc.highest_qc concatenated_aggregate_qc = AggregateQc( diff --git a/carnot/test_carnot_vote_aggregation.py b/carnot/test_carnot_vote_aggregation.py index 0cc2693..abff4c3 100644 --- a/carnot/test_carnot_vote_aggregation.py +++ b/carnot/test_carnot_vote_aggregation.py @@ -1,5 +1,7 @@ import unittest +from typing import Set, List + import carnot from carnot.carnot_vote_aggregation import AggregateQc from carnot.merging_committees import merge_committees @@ -50,6 +52,13 @@ class TestConcatenateStandardQcs(unittest.TestCase): class TestConcatenateAggregateQcs(unittest.TestCase): + def assertSetsEqual(self, set1: Set, set2: Set): + self.assertTrue(isinstance(set1, set)) + self.assertTrue(isinstance(set2, set)) + self.assertEqual(sorted(set1), sorted(set2)) + + def assertListsEqual(self, list1: List, list2: List): + self.assertEqual(sorted(list1), sorted(list2)) def test_concatenate_aggregate_qcs_single_qc(self): # Test concatenating a single AggregateQc @@ -66,31 +75,31 @@ class TestConcatenateAggregateQcs(unittest.TestCase): def test_concatenate_aggregate_qcs_multiple_qcs(self): # Test concatenating multiple AggregateQcs qc1 = AggregateQc( - sender_ids={1, 2, 3}, qcs=[1, 2, 3], - highest_qc=3, - view=1 + highest_qc=StandardQc(9, 9, {1, 2, 3, 4, 5, 6, 7, 8, 9}), + view=9, + sender_ids={1, 2, 3} ) qc2 = AggregateQc( - sender_ids={4, 5, 6}, qcs=[4, 5, 6], - highest_qc=6, - view=2 + highest_qc=StandardQc(6, 2, {4, 5, 6}), + view=2, + sender_ids={4, 5, 6} ) qc3 = AggregateQc( - sender_ids={7, 8, 9}, qcs=[7, 8, 9], - highest_qc=9, - view=3 + highest_qc=StandardQc(9, 3, {7, 8, 9}), + view=3, + sender_ids={7, 8, 9} ) aggregate_qcs = {qc1, qc2, qc3} concatenated_qc = carnot.carnot_vote_aggregation.Carnot2.concatenate_aggregate_qcs(aggregate_qcs) # Assert that the concatenated AggregateQc has the correct attributes - self.assertEqual(concatenated_qc.sender_ids, {1, 2, 3, 4, 5, 6, 7, 8, 9}) - self.assertEqual(concatenated_qc.qcs, [1, 2, 3, 4, 5, 6, 7, 8, 9]) - self.assertEqual(concatenated_qc.highest_qc, 9) - self.assertEqual(concatenated_qc.view, 1) # View should be from the first QC + self.assertSetsEqual(concatenated_qc.sender_ids, {1, 2, 3, 4, 5, 6, 7, 8, 9}) + self.assertListsEqual(sorted(concatenated_qc.qcs), [1, 2, 3, 4, 5, 6, 7, 8, 9]) + self.assertEqual(concatenated_qc.highest_qc.view, 9) + self.assertEqual(sorted(concatenated_qc.highest_qc.voters), [1, 2, 3, 4, 5, 6, 7, 8, 9]) if __name__ == '__main__':