From 6518350ff499cf5535dcfed18c217f49fe6f64b9 Mon Sep 17 00:00:00 2001 From: mjalalzai <33738574+MForensic@users.noreply.github.com> Date: Sun, 11 Jun 2023 13:51:43 -0700 Subject: [PATCH] Merging bitarrays --- carnot/PoS_attestation.py | 50 ++++++++++++++++++++++++++++++++------ carnot/test_attestation.py | 17 +++++++++++++ 2 files changed, 59 insertions(+), 8 deletions(-) create mode 100644 carnot/test_attestation.py diff --git a/carnot/PoS_attestation.py b/carnot/PoS_attestation.py index ad27db8..7650ff6 100644 --- a/carnot/PoS_attestation.py +++ b/carnot/PoS_attestation.py @@ -1,6 +1,8 @@ -import random import zlib +import random +import hashlib +from typing import List from bitarray import bitarray @@ -8,7 +10,7 @@ from bitarray import bitarray # aggregation and verification. # A node receives bitarrays from its children, containing information on votes from its grand child committees. -def count_on_bitarray_fields(bitarrays, threshold, threshold2): +def count_on_bitarray_fields(bitarrays, majority_threshold, threshold2): assert all(len(bitarray) == len(bitarrays[0]) for bitarray in bitarrays), "All bit arrays must have the same length" assert all(sum(bitarray) >= threshold2 for bitarray in bitarrays), "Each bit array must have at least threshold2 number of 'on' bits" @@ -20,12 +22,14 @@ def count_on_bitarray_fields(bitarrays, threshold, threshold2): for i in range(array_size): count = sum(bitarray[i] for bitarray in bitarrays) - if count >= threshold: + if count >= majority_threshold: result[i] = 1 # or True return result + + bitarrays = [ [1, 0, 1, 0, 1], [0, 0, 1, 1, 1], @@ -39,18 +43,19 @@ result = count_on_bitarray_fields(bitarrays, threshold, threshold2) print(result) # Output: [1, 0, 1, 0, 1] -def getIndex(idSet, sender): - for index, voter in enumerate(idSet): - if sender == voter: +def getIndex(voteSet, sender): + for index, vote in enumerate(voteSet): + if sender == vote.voter: return index return -1 # Return -1 if the sender is not found in the idSet def createCommitteeBitArray(voters, committee_size): committee_bit_array = [False] * committee_size - + assert committee_size >= len(voters) for vote in voters: - sender = vote.sender + sender = vote.voter + print("voter is ", vote.voter) index = getIndex(voters, sender) if index >= 0 and index < committee_size: committee_bit_array[index] = True @@ -100,3 +105,32 @@ def decompressBitArray(compressed_data): +class Node: + def __init__(self, identifier, stake): + self.identifier = identifier + self.stake = stake + +def select_leader(nodes: List[Node], random_beacon: int) -> Node: + total_stake = sum(node.stake for node in nodes) + + # calculate weighted hash output for each node + weighted_hash_outputs = [] + for node in nodes: + hash_input = str(random_beacon) + str(node.identifier) + hash_output = int(hashlib.sha256(hash_input.encode()).hexdigest(), 16) + weighted_hash_output = hash_output * node.stake + weighted_hash_outputs.append(weighted_hash_output) + + # normalize weighted hash outputs to ensure that their sum is equal to total stake + normalized_weighted_hash_outputs = [x / sum(weighted_hash_outputs) * total_stake for x in weighted_hash_outputs] + + # select leader based on normalized weighted hash outputs + random_number = random.uniform(0, total_stake) + cumulative_weighted_hash_output = 0 + for i, node in enumerate(nodes): + cumulative_weighted_hash_output += normalized_weighted_hash_outputs[i] + if cumulative_weighted_hash_output >= random_number: + selected_leader = node + break + + return selected_leader diff --git a/carnot/test_attestation.py b/carnot/test_attestation.py new file mode 100644 index 0000000..73131b0 --- /dev/null +++ b/carnot/test_attestation.py @@ -0,0 +1,17 @@ +import unittest + +from PoS_attestation import * + + +class TestCountOnBitarrayFields(unittest.TestCase): + def test_count_on_bitarray_fields(self): + bitarrays = [[1, 0, 1], [0, 1, 1], [1, 1, 0]] + majority_threshold = 2 + threshold2 = 1 + + result = count_on_bitarray_fields(bitarrays, majority_threshold, threshold2) + expected_result = [1, 1, 1] + print("result is ", result) + self.assertEqual(result, expected_result) + +