diff --git a/DAS/simulator.py b/DAS/simulator.py index 6a369a6..f1ccf76 100644 --- a/DAS/simulator.py +++ b/DAS/simulator.py @@ -61,8 +61,8 @@ class Simulator: for u, v in G.edges: val1=rowChannels[id][u] val2=rowChannels[id][v] - val1.rowNeighbors[id].append(Neighbor(val2)) - val2.rowNeighbors[id].append(Neighbor(val1)) + val1.rowNeighbors[id].update({val2.ID : Neighbor(val2, self.shape.blockSize)}) + val2.rowNeighbors[id].update({val1.ID : Neighbor(val1, self.shape.blockSize)}) if (len(columnChannels[id]) < self.shape.netDegree): self.logger.error("Graph degree higher than %d" % len(columnChannels[id]), extra=self.format) @@ -72,8 +72,8 @@ class Simulator: for u, v in G.edges: val1=columnChannels[id][u] val2=columnChannels[id][v] - val1.columnNeighbors[id].append(Neighbor(val2)) - val2.columnNeighbors[id].append(Neighbor(val1)) + val1.columnNeighbors[id].update({val2.ID : Neighbor(val2, self.shape.blockSize)}) + val2.columnNeighbors[id].update({val1.ID : Neighbor(val1, self.shape.blockSize)}) def initLogger(self): logger = logging.getLogger("DAS") diff --git a/DAS/validator.py b/DAS/validator.py index c2c0124..67ee7e8 100644 --- a/DAS/validator.py +++ b/DAS/validator.py @@ -12,8 +12,10 @@ class Neighbor: def __repr__(self): return str(self.node.ID) - def __init__(self, v): + def __init__(self, v, blockSize): self.node = v + self.received = zeros(blockSize) + self.sent = zeros(blockSize) class Validator: @@ -49,8 +51,8 @@ class Validator: #self.columnIDs = random.sample(range(self.shape.blockSize), self.shape.chi) self.changedRow = {id:False for id in self.rowIDs} self.changedColumn = {id:False for id in self.columnIDs} - self.rowNeighbors = collections.defaultdict(list) - self.columnNeighbors = collections.defaultdict(list) + self.rowNeighbors = collections.defaultdict(dict) + self.columnNeighbors = collections.defaultdict(dict) def logIDs(self): if self.amIproposer == 1: @@ -93,14 +95,18 @@ class Validator: def getRow(self, index): return self.block.getRow(index) - def receiveColumn(self, id, column): + def receiveColumn(self, id, column, src): if id in self.columnIDs: + # register receive so that we are not sending back + self.columnNeighbors[id][src].received |= column self.receivedBlock.mergeColumn(id, column) else: pass - def receiveRow(self, id, row): + def receiveRow(self, id, row, src): if id in self.rowIDs: + # register receive so that we are not sending back + self.rowNeighbors[id][src].received |= row self.receivedBlock.mergeRow(id, row) else: pass @@ -129,15 +135,25 @@ class Validator: line = self.getColumn(columnID) if line.any(): self.logger.debug("col %d -> %s", columnID, self.columnNeighbors[columnID] , extra=self.format) - for n in self.columnNeighbors[columnID]: - n.node.receiveColumn(columnID, line) + for n in self.columnNeighbors[columnID].values(): + + # if there is anything new to send, send it + toSend = line & ~n.sent & ~n.received + if (toSend).any(): + n.sent |= toSend; + n.node.receiveColumn(columnID, toSend, self.ID) def sendRow(self, rowID): line = self.getRow(rowID) if line.any(): self.logger.debug("row %d -> %s", rowID, self.rowNeighbors[rowID], extra=self.format) - for n in self.rowNeighbors[rowID]: - n.node.receiveRow(rowID, line) + for n in self.rowNeighbors[rowID].values(): + + # if there is anything new to send, send it + toSend = line & ~n.sent & ~n.received + if (toSend).any(): + n.sent |= toSend; + n.node.receiveRow(rowID, toSend, self.ID) def sendRows(self): if self.amIproposer == 1: