implement the new nextMsg semantics

This commit is contained in:
Zahary Karadjov 2018-07-08 19:12:06 +03:00
parent 9057d18abe
commit f1001c45d2

View File

@ -47,7 +47,8 @@ type
connectionState: ConnectionState
remote*: Node
protocolStates: seq[RootRef]
outstandingRequests*: seq[Deque[OutstandingRequest]]
outstandingRequests: seq[Deque[OutstandingRequest]]
awaitedMessages: seq[FutureBase]
PeerPool* = ref object
keyPair: KeyPair
@ -62,14 +63,16 @@ type
MessageHandler* = proc(x: Peer, data: Rlp): Future[void]
MessageContentPrinter* = proc(msg: pointer): string
MessageFutureResolver* = proc(msg: pointer, future: FutureBase)
RequestResolver* = proc(msg: pointer, future: FutureBase)
NextMsgResolver* = proc(msgData: Rlp, future: FutureBase)
MessageInfo* = object
id*: int
name*: string
thunk*: MessageHandler
printer*: MessageContentPrinter
futureResolver*: MessageFutureResolver
requestResolver: RequestResolver
nextMsgResolver: NextMsgResolver
CapabilityName* = array[3, char]
@ -213,7 +216,11 @@ proc cmp*(lhs, rhs: ProtocolInfo): int {.inline.} =
proc messagePrinter[MsgType](msg: pointer): string =
result = $(cast[ptr MsgType](msg)[])
proc messageFutureResolver[MsgType](msg: pointer, future: FutureBase) =
proc nextMsgResolver[MsgType](msgData: Rlp, future: FutureBase) =
var reader = msgData
Future[MsgType](future).complete reader.read(MsgType)
proc requestResolver[MsgType](msg: pointer, future: FutureBase) =
var f = Future[Option[MsgType]](future)
if not f.finished:
if msg != nil:
@ -237,12 +244,14 @@ proc registerMsg(protocol: var ProtocolInfo,
id: int, name: string,
thunk: MessageHandler,
printer: MessageContentPrinter,
futureResolver: MessageFutureResolver) =
requestResolver: RequestResolver,
nextMsgResolver: NextMsgResolver) =
protocol.messages.add MessageInfo(id: id,
name: name,
thunk: thunk,
printer: printer,
futureResolver: futureResolver)
requestResolver: requestResolver,
nextMsgResolver: nextMsgResolver)
proc registerProtocol(protocol: ProtocolInfo) =
# TODO: This can be done at compile-time in the future
@ -299,8 +308,8 @@ proc registerRequest(peer: Peer,
peer.outstandingRequests[responseMsgId].addLast req
# XXX: is this safe?
let futureResolver = peer.dispatcher.messages[responseMsgId].futureResolver
proc timeoutExpired(udata: pointer) = futureResolver(nil, responseFuture)
let requestResolver = peer.dispatcher.messages[responseMsgId].requestResolver
proc timeoutExpired(udata: pointer) = requestResolver(nil, responseFuture)
addTimer(timeoutAt, timeoutExpired, nil)
@ -312,7 +321,7 @@ proc resolveResponseFuture(peer: Peer, msgId: int, msg: pointer, reqId: int) =
remotePeer = peer.remote
template resolve(future) =
peer.dispatcher.messages[msgId].futureResolver(msg, future)
peer.dispatcher.messages[msgId].requestResolver(msg, future)
template outstandingReqs: auto =
peer.outstandingRequests[msgId]
@ -433,16 +442,27 @@ proc nextMsg*(peer: Peer, MsgType: typedesc): Future[MsgType] {.async.} =
## respective handlers. The designated message handler will also run
## to completion before the future returned by `nextMsg` is resolved.
const wantedId = MsgType.msgId
if peer.awaitedMessages[wantedId] != nil:
return Future[MsgType](peer.awaitedMessages[wantedId])
new result
peer.awaitedMessages[wantedId] = result
proc dispatchMessages*(peer: Peer) {.async.} =
while true:
var (nextMsgId, nextMsgData) = await peer.recvMsg()
# echo "got msg(", nextMsgId, "): ", nextMsgData.inspect
if nextMsgData.listLen != 0:
var (msgId, msgData) = await peer.recvMsg()
# echo "got msg(", msgId, "): ", msgData.inspect
if msgData.listLen != 0:
# TODO: this should be `enterList`
nextMsgData = nextMsgData.listElem(0)
await peer.dispatchMsg(nextMsgId, nextMsgData)
if nextMsgId == wantedId:
return nextMsgData.read(MsgType)
msgData = msgData.listElem(0)
await peer.dispatchMsg(msgId, msgData)
if peer.awaitedMessages[msgId] != nil:
let msgInfo = peer.dispatcher.messages[msgId]
msgInfo.nextMsgResolver(msgData, peer.awaitedMessages[msgId])
peer.awaitedMessages[msgId] = nil
iterator typedParams(n: NimNode, skip = 0): (NimNode, NimNode) =
for i in (1 + skip) ..< n.params.len:
@ -520,7 +540,8 @@ macro rlpxProtocol*(protoIdentifier: untyped,
networkStateType: NimNode = nil
useRequestIds = true
messagePrinter = bindSym "messagePrinter"
messageFutureResolver = bindSym "messageFutureResolver"
requestResolver = bindSym "requestResolver"
nextMsgResolver = bindSym "nextMsgResolver"
# By convention, all Ethereum protocol names must be abbreviated to 3 letters
assert protoName.len == 3
@ -736,7 +757,8 @@ macro rlpxProtocol*(protoIdentifier: untyped,
newStrLitNode($n.name),
thunkName,
newTree(nnkBracketExpr, messagePrinter, msgRecord),
newTree(nnkBracketExpr, messageFutureResolver, msgRecord)))
newTree(nnkBracketExpr, requestResolver, msgRecord),
newTree(nnkBracketExpr, nextMsgResolver, msgRecord)))
result = finalOutput
result.add quote do:
@ -886,6 +908,10 @@ proc connectionEstablished(p: Peer, h: p2p.hello) =
p.outstandingRequests.newSeq(p.dispatcher.messages.len)
for d in mitems(p.outstandingRequests): d = initDeque[OutstandingRequest](0)
# similarly, we need a bit of book-keeping data to keep track of the
# potentially concurrent calls to `nextMsg`.
p.awaitedMessages.newSeq(p.dispatcher.messages.len)
p.nextReqId = 1
# p.id = h.nodeId