implement the new nextMsg semantics

This commit is contained in:
Zahary Karadjov 2018-07-08 19:12:06 +03:00
parent 9057d18abe
commit f1001c45d2

View File

@ -47,7 +47,8 @@ type
connectionState: ConnectionState connectionState: ConnectionState
remote*: Node remote*: Node
protocolStates: seq[RootRef] protocolStates: seq[RootRef]
outstandingRequests*: seq[Deque[OutstandingRequest]] outstandingRequests: seq[Deque[OutstandingRequest]]
awaitedMessages: seq[FutureBase]
PeerPool* = ref object PeerPool* = ref object
keyPair: KeyPair keyPair: KeyPair
@ -62,14 +63,16 @@ type
MessageHandler* = proc(x: Peer, data: Rlp): Future[void] MessageHandler* = proc(x: Peer, data: Rlp): Future[void]
MessageContentPrinter* = proc(msg: pointer): string MessageContentPrinter* = proc(msg: pointer): string
MessageFutureResolver* = proc(msg: pointer, future: FutureBase) RequestResolver* = proc(msg: pointer, future: FutureBase)
NextMsgResolver* = proc(msgData: Rlp, future: FutureBase)
MessageInfo* = object MessageInfo* = object
id*: int id*: int
name*: string name*: string
thunk*: MessageHandler thunk*: MessageHandler
printer*: MessageContentPrinter printer*: MessageContentPrinter
futureResolver*: MessageFutureResolver requestResolver: RequestResolver
nextMsgResolver: NextMsgResolver
CapabilityName* = array[3, char] CapabilityName* = array[3, char]
@ -213,7 +216,11 @@ proc cmp*(lhs, rhs: ProtocolInfo): int {.inline.} =
proc messagePrinter[MsgType](msg: pointer): string = proc messagePrinter[MsgType](msg: pointer): string =
result = $(cast[ptr MsgType](msg)[]) result = $(cast[ptr MsgType](msg)[])
proc messageFutureResolver[MsgType](msg: pointer, future: FutureBase) = proc nextMsgResolver[MsgType](msgData: Rlp, future: FutureBase) =
var reader = msgData
Future[MsgType](future).complete reader.read(MsgType)
proc requestResolver[MsgType](msg: pointer, future: FutureBase) =
var f = Future[Option[MsgType]](future) var f = Future[Option[MsgType]](future)
if not f.finished: if not f.finished:
if msg != nil: if msg != nil:
@ -237,12 +244,14 @@ proc registerMsg(protocol: var ProtocolInfo,
id: int, name: string, id: int, name: string,
thunk: MessageHandler, thunk: MessageHandler,
printer: MessageContentPrinter, printer: MessageContentPrinter,
futureResolver: MessageFutureResolver) = requestResolver: RequestResolver,
nextMsgResolver: NextMsgResolver) =
protocol.messages.add MessageInfo(id: id, protocol.messages.add MessageInfo(id: id,
name: name, name: name,
thunk: thunk, thunk: thunk,
printer: printer, printer: printer,
futureResolver: futureResolver) requestResolver: requestResolver,
nextMsgResolver: nextMsgResolver)
proc registerProtocol(protocol: ProtocolInfo) = proc registerProtocol(protocol: ProtocolInfo) =
# TODO: This can be done at compile-time in the future # TODO: This can be done at compile-time in the future
@ -299,8 +308,8 @@ proc registerRequest(peer: Peer,
peer.outstandingRequests[responseMsgId].addLast req peer.outstandingRequests[responseMsgId].addLast req
# XXX: is this safe? # XXX: is this safe?
let futureResolver = peer.dispatcher.messages[responseMsgId].futureResolver let requestResolver = peer.dispatcher.messages[responseMsgId].requestResolver
proc timeoutExpired(udata: pointer) = futureResolver(nil, responseFuture) proc timeoutExpired(udata: pointer) = requestResolver(nil, responseFuture)
addTimer(timeoutAt, timeoutExpired, nil) addTimer(timeoutAt, timeoutExpired, nil)
@ -312,7 +321,7 @@ proc resolveResponseFuture(peer: Peer, msgId: int, msg: pointer, reqId: int) =
remotePeer = peer.remote remotePeer = peer.remote
template resolve(future) = template resolve(future) =
peer.dispatcher.messages[msgId].futureResolver(msg, future) peer.dispatcher.messages[msgId].requestResolver(msg, future)
template outstandingReqs: auto = template outstandingReqs: auto =
peer.outstandingRequests[msgId] peer.outstandingRequests[msgId]
@ -433,16 +442,27 @@ proc nextMsg*(peer: Peer, MsgType: typedesc): Future[MsgType] {.async.} =
## respective handlers. The designated message handler will also run ## respective handlers. The designated message handler will also run
## to completion before the future returned by `nextMsg` is resolved. ## to completion before the future returned by `nextMsg` is resolved.
const wantedId = MsgType.msgId const wantedId = MsgType.msgId
if peer.awaitedMessages[wantedId] != nil:
return Future[MsgType](peer.awaitedMessages[wantedId])
new result
peer.awaitedMessages[wantedId] = result
proc dispatchMessages*(peer: Peer) {.async.} =
while true: while true:
var (nextMsgId, nextMsgData) = await peer.recvMsg() var (msgId, msgData) = await peer.recvMsg()
# echo "got msg(", nextMsgId, "): ", nextMsgData.inspect
if nextMsgData.listLen != 0: # echo "got msg(", msgId, "): ", msgData.inspect
if msgData.listLen != 0:
# TODO: this should be `enterList` # TODO: this should be `enterList`
nextMsgData = nextMsgData.listElem(0) msgData = msgData.listElem(0)
await peer.dispatchMsg(nextMsgId, nextMsgData)
if nextMsgId == wantedId: await peer.dispatchMsg(msgId, msgData)
return nextMsgData.read(MsgType)
if peer.awaitedMessages[msgId] != nil:
let msgInfo = peer.dispatcher.messages[msgId]
msgInfo.nextMsgResolver(msgData, peer.awaitedMessages[msgId])
peer.awaitedMessages[msgId] = nil
iterator typedParams(n: NimNode, skip = 0): (NimNode, NimNode) = iterator typedParams(n: NimNode, skip = 0): (NimNode, NimNode) =
for i in (1 + skip) ..< n.params.len: for i in (1 + skip) ..< n.params.len:
@ -520,7 +540,8 @@ macro rlpxProtocol*(protoIdentifier: untyped,
networkStateType: NimNode = nil networkStateType: NimNode = nil
useRequestIds = true useRequestIds = true
messagePrinter = bindSym "messagePrinter" messagePrinter = bindSym "messagePrinter"
messageFutureResolver = bindSym "messageFutureResolver" requestResolver = bindSym "requestResolver"
nextMsgResolver = bindSym "nextMsgResolver"
# By convention, all Ethereum protocol names must be abbreviated to 3 letters # By convention, all Ethereum protocol names must be abbreviated to 3 letters
assert protoName.len == 3 assert protoName.len == 3
@ -736,7 +757,8 @@ macro rlpxProtocol*(protoIdentifier: untyped,
newStrLitNode($n.name), newStrLitNode($n.name),
thunkName, thunkName,
newTree(nnkBracketExpr, messagePrinter, msgRecord), newTree(nnkBracketExpr, messagePrinter, msgRecord),
newTree(nnkBracketExpr, messageFutureResolver, msgRecord))) newTree(nnkBracketExpr, requestResolver, msgRecord),
newTree(nnkBracketExpr, nextMsgResolver, msgRecord)))
result = finalOutput result = finalOutput
result.add quote do: result.add quote do:
@ -886,6 +908,10 @@ proc connectionEstablished(p: Peer, h: p2p.hello) =
p.outstandingRequests.newSeq(p.dispatcher.messages.len) p.outstandingRequests.newSeq(p.dispatcher.messages.len)
for d in mitems(p.outstandingRequests): d = initDeque[OutstandingRequest](0) for d in mitems(p.outstandingRequests): d = initDeque[OutstandingRequest](0)
# similarly, we need a bit of book-keeping data to keep track of the
# potentially concurrent calls to `nextMsg`.
p.awaitedMessages.newSeq(p.dispatcher.messages.len)
p.nextReqId = 1 p.nextReqId = 1
# p.id = h.nodeId # p.id = h.nodeId