diff --git a/eth/p2p/p2p_protocol_dsl.nim b/eth/p2p/p2p_protocol_dsl.nim index d002aae..bc1409f 100644 --- a/eth/p2p/p2p_protocol_dsl.nim +++ b/eth/p2p/p2p_protocol_dsl.nim @@ -1,5 +1,5 @@ import - options, + options, sequtils, stew/shims/macros, chronos, faststreams/output_stream type @@ -15,7 +15,8 @@ type kind*: MessageKind procDef*: NimNode timeoutParam*: NimNode - recIdent*: NimNode + recName*: NimNode + strongRecName*: NimNode recBody*: NimNode protocol*: P2PProtocol response*: Message @@ -346,7 +347,7 @@ proc hasReqId*(msg: Message): bool = proc ResponderType(msg: Message): NimNode = var resp = if msg.kind == msgRequest: msg.response else: msg newTree(nnkBracketExpr, - msg.protocol.backend.ResponderType, resp.recIdent) + msg.protocol.backend.ResponderType, resp.recName) proc newMsg(protocol: P2PProtocol, kind: MessageKind, id: int, procDef: NimNode, timeoutParam: NimNode = nil, @@ -361,18 +362,35 @@ proc newMsg(protocol: P2PProtocol, kind: MessageKind, id: int, msgName = $msgIdent recFields = newTree(nnkRecList) recBody = newTree(nnkObjectTy, newEmptyNode(), newEmptyNode(), recFields) - recName = ident(msgName & "Obj") + strongRecName = ident(msgName & "Obj") + recName = strongRecName for param, paramType in procDef.typedParams(skip = 1): recFields.add newTree(nnkIdentDefs, newTree(nnkPostfix, ident("*"), param), # The fields are public chooseFieldType(paramType), # some types such as openarray - # are automatically remapped - newEmptyNode()) + newEmptyNode()) # are automatically remapped - result = Message(protocol: protocol, id: id, ident: msgIdent, kind: kind, - procDef: procDef, recIdent: recName, recBody: recBody, - timeoutParam: timeoutParam, response: response) + if recFields.len == 1: + # When we have a single parameter, it's treated as the transferred message + # type. `recName` will be resolved to the message type that's intended + # for serialization while `strongRecName` will be a distinct type over + # which overloads such as `msgId` can be defined. We must use a distinct + # type because otherwise Nim may see multiple overloads defined over the + # same request parameter type and this will be an ambiguity error. + recName = recFields[0][1] + recBody = newTree(nnkDistinctTy, recName) + + result = Message(protocol: protocol, + id: id, + ident: msgIdent, + kind: kind, + procDef: procDef, + recName: recName, + strongRecName: strongRecName, + recBody: recBody, + timeoutParam: timeoutParam, + response: response) if procDef.body.kind != nnkEmpty: var userHandler = copy procDef @@ -482,9 +500,9 @@ proc createSendProc*(msg: Message, def[3][0] = if procType == nnkMacroDef: ident "untyped" elif msg.kind == msgRequest and not isRawSender: - Fut(Opt(msg.response.recIdent)) + Fut(Opt(msg.response.recName)) elif msg.kind == msgHandshake and not isRawSender: - Fut(msg.recIdent) + Fut(msg.recName) else: Fut(Void) @@ -518,13 +536,20 @@ proc writeParamsAsRecord*(params: openarray[NimNode], writer, recordWriterCtx, newLit($param), param) - result = quote do: - mixin init, writerType, beginRecord, endRecord + if params.len > 1: + result = quote do: + mixin init, writerType, beginRecord, endRecord - var `writer` = init(WriterType(`Format`), `outputStream`) - var `recordWriterCtx` = beginRecord(`writer`, `RecordType`) - `appendParams` - endRecord(`writer`, `recordWriterCtx`) + var `writer` = init(WriterType(`Format`), `outputStream`) + var `recordWriterCtx` = beginRecord(`writer`, `RecordType`) + `appendParams` + endRecord(`writer`, `recordWriterCtx`) + else: + let param = params[0] + + result = quote do: + var `writer` = init(WriterType(`Format`), `outputStream`) + writeValue(`writer`, `param`) proc useStandardBody*(sendProc: SendProc, preSerializationStep: proc(stream: NimNode): NimNode, @@ -537,7 +562,7 @@ proc useStandardBody*(sendProc: SendProc, initFuture = bindSym "initFuture" recipient = sendProc.peerParam - msgRecName = msg.recIdent + msgRecName = msg.recName Format = msg.protocol.backend.SerializationFormat preSerialization = if preSerializationStep.isNil: newStmtList() @@ -591,7 +616,7 @@ proc createSerializer*(msg: Message, procType = nnkProcDef): NimNode = serializer.msgParams, streamVar, msg.protocol.backend.SerializationFormat, - msg.recIdent) + msg.recName) return serializer.def @@ -612,10 +637,12 @@ proc genAwaitUserHandler*(msg: Message, receivedMsg: NimNode, var userHandlerCall = newCall(msg.userHandler.name, leadingParams) - for param, paramType in msg.procDef.typedParams(skip = 1): - # If there is user message handler, we'll place a call to it by - # unpacking the fields of the received message: - userHandlerCall.add newDotExpr(receivedMsg, param) + var params = toSeq(msg.procDef.typedParams(skip = 1)) + if params.len > 1: + for p in params: + userHandlerCall.add newDotExpr(receivedMsg, p[0]) + else: + userHandlerCall.add receivedMsg return newCall("await", userHandlerCall) @@ -646,7 +673,7 @@ proc createHandshakeTemplate*(msg: Message, peerValue = forwardCall[1] timeoutValue = msg.timeoutParam[0] peerVarSym = genSym(nskLet, "peer") - msgRecName = msg.recIdent + msgRecName = msg.recName forwardCall[1] = peerVarSym forwardCall.del(forwardCall.len - 1) @@ -771,12 +798,13 @@ proc genTypeSection*(p: P2PProtocol): NimNode = let msgId = msg.id msgName = msg.ident - msgRecName = msg.recIdent + msgRecName = msg.recName + msgStrongRecName = msg.strongRecName msgRecBody = msg.recBody result.add quote do: # This is a type featuring a single field for each message param: - type `msgRecName`* = `msgRecBody` + type `msgStrongRecName`* = `msgRecBody` # Add a helper template for accessing the message type: # e.g. p2p.hello: @@ -784,8 +812,8 @@ proc genTypeSection*(p: P2PProtocol): NimNode = # Add a helper template for obtaining the message Id for # a particular message type: - template msgId*(T: type `msgRecName`): int = `msgId` - template msgProtocol*(T: type `msgRecName`): type = `protocolName` + template msgId*(T: type `msgStrongRecName`): int = `msgId` + template msgProtocol*(T: type `msgStrongRecName`): type = `protocolName` proc genCode*(p: P2PProtocol): NimNode = # TODO: try switching to a simpler for msg in p.messages: loop diff --git a/eth/p2p/rlpx.nim b/eth/p2p/rlpx.nim index d64aa6f..9459866 100644 --- a/eth/p2p/rlpx.nim +++ b/eth/p2p/rlpx.nim @@ -609,9 +609,9 @@ proc p2pProtocolBackendImpl*(protocol: P2PProtocol): Backend = msgId = msg.id msgIdent = msg.ident msgName = $msgIdent - msgRecName = msg.recIdent + msgRecName = msg.recName responseMsgId = if msg.response != nil: msg.response.id else: -1 - ResponseRecord = if msg.response != nil: msg.response.recIdent else: nil + ResponseRecord = if msg.response != nil: msg.response.recName else: nil hasReqId = msg.hasReqId protocol = msg.protocol userPragmas = msg.procDef.pragma