From c97dd58d76e401595c85cebbbcd8ba9835ccf6ab Mon Sep 17 00:00:00 2001 From: Csaba Kiraly Date: Wed, 25 Jan 2023 21:36:53 +0100 Subject: [PATCH] keep track of sent and received samples per neighbor Keeps track of sent and received samples per line per neighbor. Only send what wasn't yet sent or wasn't received from the other side. Signed-off-by: Csaba Kiraly --- DAS/simulator.py | 8 ++++---- DAS/validator.py | 34 +++++++++++++++++++++++++--------- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/DAS/simulator.py b/DAS/simulator.py index 6a369a6..f1ccf76 100644 --- a/DAS/simulator.py +++ b/DAS/simulator.py @@ -61,8 +61,8 @@ class Simulator: for u, v in G.edges: val1=rowChannels[id][u] val2=rowChannels[id][v] - val1.rowNeighbors[id].append(Neighbor(val2)) - val2.rowNeighbors[id].append(Neighbor(val1)) + val1.rowNeighbors[id].update({val2.ID : Neighbor(val2, self.shape.blockSize)}) + val2.rowNeighbors[id].update({val1.ID : Neighbor(val1, self.shape.blockSize)}) if (len(columnChannels[id]) < self.shape.netDegree): self.logger.error("Graph degree higher than %d" % len(columnChannels[id]), extra=self.format) @@ -72,8 +72,8 @@ class Simulator: for u, v in G.edges: val1=columnChannels[id][u] val2=columnChannels[id][v] - val1.columnNeighbors[id].append(Neighbor(val2)) - val2.columnNeighbors[id].append(Neighbor(val1)) + val1.columnNeighbors[id].update({val2.ID : Neighbor(val2, self.shape.blockSize)}) + val2.columnNeighbors[id].update({val1.ID : Neighbor(val1, self.shape.blockSize)}) def initLogger(self): logger = logging.getLogger("DAS") diff --git a/DAS/validator.py b/DAS/validator.py index c2c0124..67ee7e8 100644 --- a/DAS/validator.py +++ b/DAS/validator.py @@ -12,8 +12,10 @@ class Neighbor: def __repr__(self): return str(self.node.ID) - def __init__(self, v): + def __init__(self, v, blockSize): self.node = v + self.received = zeros(blockSize) + self.sent = zeros(blockSize) class Validator: @@ -49,8 +51,8 @@ class Validator: #self.columnIDs = random.sample(range(self.shape.blockSize), self.shape.chi) self.changedRow = {id:False for id in self.rowIDs} self.changedColumn = {id:False for id in self.columnIDs} - self.rowNeighbors = collections.defaultdict(list) - self.columnNeighbors = collections.defaultdict(list) + self.rowNeighbors = collections.defaultdict(dict) + self.columnNeighbors = collections.defaultdict(dict) def logIDs(self): if self.amIproposer == 1: @@ -93,14 +95,18 @@ class Validator: def getRow(self, index): return self.block.getRow(index) - def receiveColumn(self, id, column): + def receiveColumn(self, id, column, src): if id in self.columnIDs: + # register receive so that we are not sending back + self.columnNeighbors[id][src].received |= column self.receivedBlock.mergeColumn(id, column) else: pass - def receiveRow(self, id, row): + def receiveRow(self, id, row, src): if id in self.rowIDs: + # register receive so that we are not sending back + self.rowNeighbors[id][src].received |= row self.receivedBlock.mergeRow(id, row) else: pass @@ -129,15 +135,25 @@ class Validator: line = self.getColumn(columnID) if line.any(): self.logger.debug("col %d -> %s", columnID, self.columnNeighbors[columnID] , extra=self.format) - for n in self.columnNeighbors[columnID]: - n.node.receiveColumn(columnID, line) + for n in self.columnNeighbors[columnID].values(): + + # if there is anything new to send, send it + toSend = line & ~n.sent & ~n.received + if (toSend).any(): + n.sent |= toSend; + n.node.receiveColumn(columnID, toSend, self.ID) def sendRow(self, rowID): line = self.getRow(rowID) if line.any(): self.logger.debug("row %d -> %s", rowID, self.rowNeighbors[rowID], extra=self.format) - for n in self.rowNeighbors[rowID]: - n.node.receiveRow(rowID, line) + for n in self.rowNeighbors[rowID].values(): + + # if there is anything new to send, send it + toSend = line & ~n.sent & ~n.received + if (toSend).any(): + n.sent |= toSend; + n.node.receiveRow(rowID, toSend, self.ID) def sendRows(self): if self.amIproposer == 1: