keep track of sent and received samples per neighbor

Keeps track of sent and received samples per line per neighbor.
Only send what wasn't yet sent or wasn't received from the other side.

Signed-off-by: Csaba Kiraly <csaba.kiraly@gmail.com>
This commit is contained in:
Csaba Kiraly 2023-01-25 21:36:53 +01:00
parent b38d8e13ae
commit c97dd58d76
No known key found for this signature in database
GPG Key ID: 0FE274EE8C95166E
2 changed files with 29 additions and 13 deletions

View File

@ -61,8 +61,8 @@ class Simulator:
for u, v in G.edges:
val1=rowChannels[id][u]
val2=rowChannels[id][v]
val1.rowNeighbors[id].append(Neighbor(val2))
val2.rowNeighbors[id].append(Neighbor(val1))
val1.rowNeighbors[id].update({val2.ID : Neighbor(val2, self.shape.blockSize)})
val2.rowNeighbors[id].update({val1.ID : Neighbor(val1, self.shape.blockSize)})
if (len(columnChannels[id]) < self.shape.netDegree):
self.logger.error("Graph degree higher than %d" % len(columnChannels[id]), extra=self.format)
@ -72,8 +72,8 @@ class Simulator:
for u, v in G.edges:
val1=columnChannels[id][u]
val2=columnChannels[id][v]
val1.columnNeighbors[id].append(Neighbor(val2))
val2.columnNeighbors[id].append(Neighbor(val1))
val1.columnNeighbors[id].update({val2.ID : Neighbor(val2, self.shape.blockSize)})
val2.columnNeighbors[id].update({val1.ID : Neighbor(val1, self.shape.blockSize)})
def initLogger(self):
logger = logging.getLogger("DAS")

View File

@ -12,8 +12,10 @@ class Neighbor:
def __repr__(self):
return str(self.node.ID)
def __init__(self, v):
def __init__(self, v, blockSize):
self.node = v
self.received = zeros(blockSize)
self.sent = zeros(blockSize)
class Validator:
@ -49,8 +51,8 @@ class Validator:
#self.columnIDs = random.sample(range(self.shape.blockSize), self.shape.chi)
self.changedRow = {id:False for id in self.rowIDs}
self.changedColumn = {id:False for id in self.columnIDs}
self.rowNeighbors = collections.defaultdict(list)
self.columnNeighbors = collections.defaultdict(list)
self.rowNeighbors = collections.defaultdict(dict)
self.columnNeighbors = collections.defaultdict(dict)
def logIDs(self):
if self.amIproposer == 1:
@ -93,14 +95,18 @@ class Validator:
def getRow(self, index):
return self.block.getRow(index)
def receiveColumn(self, id, column):
def receiveColumn(self, id, column, src):
if id in self.columnIDs:
# register receive so that we are not sending back
self.columnNeighbors[id][src].received |= column
self.receivedBlock.mergeColumn(id, column)
else:
pass
def receiveRow(self, id, row):
def receiveRow(self, id, row, src):
if id in self.rowIDs:
# register receive so that we are not sending back
self.rowNeighbors[id][src].received |= row
self.receivedBlock.mergeRow(id, row)
else:
pass
@ -129,15 +135,25 @@ class Validator:
line = self.getColumn(columnID)
if line.any():
self.logger.debug("col %d -> %s", columnID, self.columnNeighbors[columnID] , extra=self.format)
for n in self.columnNeighbors[columnID]:
n.node.receiveColumn(columnID, line)
for n in self.columnNeighbors[columnID].values():
# if there is anything new to send, send it
toSend = line & ~n.sent & ~n.received
if (toSend).any():
n.sent |= toSend;
n.node.receiveColumn(columnID, toSend, self.ID)
def sendRow(self, rowID):
line = self.getRow(rowID)
if line.any():
self.logger.debug("row %d -> %s", rowID, self.rowNeighbors[rowID], extra=self.format)
for n in self.rowNeighbors[rowID]:
n.node.receiveRow(rowID, line)
for n in self.rowNeighbors[rowID].values():
# if there is anything new to send, send it
toSend = line & ~n.sent & ~n.received
if (toSend).any():
n.sent |= toSend;
n.node.receiveRow(rowID, toSend, self.ID)
def sendRows(self):
if self.amIproposer == 1: