From 1467b145ae161818d505134d25bfb932f052beb7 Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Tue, 5 Nov 2024 16:30:41 +0100 Subject: [PATCH] =?UTF-8?q?remove=20unusued=20rlpx=20features,=20tighten?= =?UTF-8?q?=20hello=20exchange=20and=20some=20error=20h=E2=80=A6=20(#759)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * disconnect peers that send non-hello messages during initial hello step * fix devp2p protocol version - 4 because we don't implement snappy (yet) - this is cosmetic since this particular version field is not actually being used * fix ack message length checking * move RLPx transport code to separate module, annotate with asyncraises * increase max RLPx message size to 16mb, per EIP-706 * make sure both accept/connect timeout after 10s * aim to log every connection attempt once at debug level * make capability-id/context-id check more accurate * disallow random messages before hello --- eth/p2p.nim | 2 +- eth/p2p/auth.nim | 22 +- eth/p2p/p2p_backends_helpers.nim | 2 +- eth/p2p/p2p_protocol_dsl.nim | 75 +--- eth/p2p/private/p2p_types.nim | 15 +- eth/p2p/rlpx.nim | 696 +++++++++++++------------------ eth/p2p/rlpxcrypt.nim | 31 +- eth/p2p/rlpxtransport.nim | 243 +++++++++++ tests/p2p/all_tests.nim | 1 + tests/p2p/test_crypt.nim | 15 +- tests/p2p/test_rlpxtransport.nim | 60 +++ tests/rlp/test_api_usage.nim | 3 +- 12 files changed, 633 insertions(+), 532 deletions(-) create mode 100644 eth/p2p/rlpxtransport.nim create mode 100644 tests/p2p/test_rlpxtransport.nim diff --git a/eth/p2p.nim b/eth/p2p.nim index 9d4592b..7db932a 100644 --- a/eth/p2p.nim +++ b/eth/p2p.nim @@ -27,7 +27,7 @@ proc addCapability*(node: EthereumNode, let pos = lowerBound(node.protocols, p, rlpx.cmp) node.protocols.insert(p, pos) - node.capabilities.insert(p.asCapability, pos) + node.capabilities.insert(p.capability, pos) if p.networkStateInitializer != nil and networkState.isNil: node.protocolStates[p.index] = p.networkStateInitializer(node) diff --git a/eth/p2p/auth.nim b/eth/p2p/auth.nim index 28f53c8..b89db6d 100644 --- a/eth/p2p/auth.nim +++ b/eth/p2p/auth.nim @@ -40,7 +40,7 @@ const ## least 100 bytes of padding to make the message distinguishable from ## pre-EIP8 and at most 200 to stay within recommendation - # signature + pubkey + nounce + version + rlp encoding overhead + # signature + pubkey + nonce + version + rlp encoding overhead # 65 + 64 + 32 + 1 + 7 = 169 PlainAuthMessageEIP8Length = 169 PlainAuthMessageMaxEIP8 = PlainAuthMessageEIP8Length + MaxPadLenEIP8 @@ -57,7 +57,8 @@ const PlainAckMessageEIP8Length = 102 PlainAckMessageMaxEIP8 = PlainAckMessageEIP8Length + MaxPadLenEIP8 # Min. encrypted message + size prefix = 217 - AckMessageEIP8Length* = eciesEncryptedLength(PlainAckMessageMaxEIP8) + MsgLenLenEIP8 + AckMessageEIP8Length* = + eciesEncryptedLength(PlainAckMessageEIP8Length) + MsgLenLenEIP8 AckMessageMaxEIP8* = AckMessageEIP8Length + MaxPadLenEIP8 ## Minimal output buffer size to pass into `ackMessage` @@ -225,18 +226,27 @@ proc ackMessage*( return err(AuthError.EciesError) ok(fullsize) -proc decodeMsgLen*(h: Handshake, input: openArray[byte]): AuthResult[int] = +func decodeMsgLen(input: openArray[byte]): AuthResult[int] = if input.len < 2: return err(AuthError.IncompleteError) - let len = int(uint16.fromBytesBE(input)) + 2 + ok(int(uint16.fromBytesBE(input)) + 2) + +func decodeAuthMsgLen*(h: Handshake, input: openArray[byte]): AuthResult[int] = + let len = ?decodeMsgLen(input) if len < AuthMessageEIP8Length: return err(AuthError.IncompleteError) ok(len) +func decodeAckMsgLen*(h: Handshake, input: openArray[byte]): AuthResult[int] = + let len = ?decodeMsgLen(input) + if len < AckMessageEIP8Length: + return err(AuthError.IncompleteError) + ok(len) + proc decodeAuthMessage*(h: var Handshake, m: openArray[byte]): AuthResult[void] = ## Decodes EIP-8 AuthMessage. let - expectedLength = ?h.decodeMsgLen(m) + expectedLength = ?h.decodeAuthMsgLen(m) size = expectedLength - MsgLenLenEIP8 # Check if the prefixed size is => than the minimum @@ -289,7 +299,7 @@ proc decodeAuthMessage*(h: var Handshake, m: openArray[byte]): AuthResult[void] proc decodeAckMessage*(h: var Handshake, m: openArray[byte]): AuthResult[void] = ## Decodes EIP-8 AckMessage. let - expectedLength = ?h.decodeMsgLen(m) + expectedLength = ?h.decodeAckMsgLen(m) size = expectedLength - MsgLenLenEIP8 # Check if the prefixed size is => than the minimum diff --git a/eth/p2p/p2p_backends_helpers.nim b/eth/p2p/p2p_backends_helpers.nim index bc0597e..b04b9b0 100644 --- a/eth/p2p/p2p_backends_helpers.nim +++ b/eth/p2p/p2p_backends_helpers.nim @@ -16,7 +16,7 @@ let protocolManager = ProtocolManager() proc registerProtocol*(proto: ProtocolInfo) {.gcsafe.} = {.gcsafe.}: proto.index = protocolManager.protocols.len - if proto.name == "p2p": + if proto.capability.name == "p2p": doAssert(proto.index == 0) protocolManager.protocols.add proto diff --git a/eth/p2p/p2p_protocol_dsl.nim b/eth/p2p/p2p_protocol_dsl.nim index ef3644f..398fb81 100644 --- a/eth/p2p/p2p_protocol_dsl.nim +++ b/eth/p2p/p2p_protocol_dsl.nim @@ -25,7 +25,7 @@ import std/[options, sequtils, macrocache], results, - stew/shims/macros, chronos, faststreams/outputs + stew/shims/macros, chronos type MessageKind* = enum @@ -699,77 +699,6 @@ proc writeParamsAsRecord*(params: openArray[NimNode], var `writer` = init(WriterType(`Format`), `outputStream`) writeValue(`writer`, `param`) -proc useStandardBody*(sendProc: SendProc, - preSerializationStep: proc(stream: NimNode): NimNode, - postSerializationStep: proc(stream: NimNode): NimNode, - sendCallGenerator: proc (peer, bytes: NimNode): NimNode) = - let - msg = sendProc.msg - msgBytes = ident "msgBytes" - recipient = sendProc.peerParam - sendCall = sendCallGenerator(recipient, msgBytes) - - if sendProc.msgParams.len == 0: - sendProc.setBody quote do: - var `msgBytes`: seq[byte] - `sendCall` - return - - let - outputStream = ident "outputStream" - - msgRecName = msg.recName - Format = msg.protocol.backend.SerializationFormat - - preSerialization = if preSerializationStep.isNil: newStmtList() - else: preSerializationStep(outputStream) - - serialization = writeParamsAsRecord(sendProc.msgParams, - outputStream, Format, msgRecName) - - postSerialization = if postSerializationStep.isNil: newStmtList() - else: postSerializationStep(outputStream) - - tracing = when not tracingEnabled: - newStmtList() - else: - logSentMsgFields(recipient, - msg.protocol.protocolInfo, - $msg.ident, - sendProc.msgParams) - - sendProc.setBody quote do: - mixin init, WriterType, beginRecord, endRecord, getOutput - - var `outputStream` = memoryOutput() - `preSerialization` - `serialization` - `postSerialization` - `tracing` - let `msgBytes` = getOutput(`outputStream`) - `sendCall` - -proc correctSerializerProcParams(params: NimNode) = - # A serializer proc is just like a send proc, but: - # 1. it has a void return type - params[0] = ident "void" - # 2. The peer params is replaced with OutputStream - params[1] = newIdentDefs(streamVar, bindSym "OutputStream") - # 3. The timeout param is removed - params.del(params.len - 1) - -proc createSerializer*(msg: Message, procType = nnkProcDef): NimNode = - var serializer = msg.createSendProc(procType, nameSuffix = "Serializer") - correctSerializerProcParams serializer.def.params - - serializer.setBody writeParamsAsRecord( - serializer.msgParams, - streamVar, - msg.protocol.backend.SerializationFormat, - msg.recName) - - return serializer.def - proc defineThunk*(msg: Message, thunk: NimNode) = let protocol = msg.protocol @@ -1019,7 +948,7 @@ proc genCode*(p: P2PProtocol): NimNode = regBody.add newCall(p.backend.registerProtocol, protocolVar) result.add quote do: - proc `protocolReg`() {.raises: [RlpError].} = + proc `protocolReg`() = let `protocolVar` = `protocolInit` `regBody` `protocolReg`() diff --git a/eth/p2p/private/p2p_types.nim b/eth/p2p/private/p2p_types.nim index abf7f58..7812f96 100644 --- a/eth/p2p/private/p2p_types.nim +++ b/eth/p2p/private/p2p_types.nim @@ -15,9 +15,9 @@ import chronos, results, ".."/../[rlp], ../../common/[base, keys], - ".."/[enode, kademlia, discovery, rlpxcrypt] + ".."/[enode, kademlia, discovery, rlpxtransport] -export base.NetworkId +export base.NetworkId, rlpxtransport const useSnappy* = defined(useSnappy) @@ -48,16 +48,16 @@ type network*: EthereumNode # Private fields: - transport*: StreamTransport + transport*: RlpxTransport dispatcher*: Dispatcher lastReqId*: Opt[uint64] - secretsState*: SecretState connectionState*: ConnectionState protocolStates*: seq[RootRef] outstandingRequests*: seq[Deque[OutstandingRequest]] # per `msgId` table awaitedMessages*: seq[FutureBase] # per `msgId` table when useSnappy: snappyEnabled*: bool + clientId*: string SeenNode* = object nodeId*: NodeId @@ -111,8 +111,7 @@ type protocols*: seq[ProtocolInfo] ProtocolInfo* = ref object - name*: string - version*: uint64 + capability*: Capability messages*: seq[MessageInfo] index*: int # the position of the protocol in the # ordered list of supported protocols @@ -209,12 +208,14 @@ type ClientQuitting = 0x08, UnexpectedIdentity = 0x09, SelfConnection = 0x0A, - MessageTimeout = 0x0B, + PingTimeout = 0x0B, SubprotocolReason = 0x10 Address = enode.Address proc `$`*(peer: Peer): string = $peer.remote +proc `$`*(v: Capability): string = v.name & "/" & $v.version + proc toENode*(v: EthereumNode): ENode = ENode(pubkey: v.keys.pubkey, address: v.address) diff --git a/eth/p2p/rlpx.nim b/eth/p2p/rlpx.nim index a27571f..e7d4d9c 100644 --- a/eth/p2p/rlpx.nim +++ b/eth/p2p/rlpx.nim @@ -25,17 +25,32 @@ {.push raises: [].} import - std/[algorithm, deques, options, typetraits, os], - stew/shims/macros, chronicles, nimcrypto/utils, chronos, metrics, + std/[algorithm, deques, options, os, sequtils, strutils, typetraits], + stew/shims/macros, chronicles, chronos, metrics, ".."/[rlp, async_utils], ./private/p2p_types, "."/[kademlia, auth, rlpxcrypt, enode, p2p_protocol_dsl] +const + devp2pVersion* = 4 + connectionTimeout = 10.seconds + + msgIdHello = byte 0 + msgIdDisconnect = byte 1 + msgIdPing = byte 2 + msgIdPong = byte 3 + # TODO: This doesn't get enabled currently in any of the builds, so we send a # devp2p protocol handshake message with version. Need to check if some peers # drop us because of this. when useSnappy: import snappy - const devp2pSnappyVersion* = 5 + const + devp2pSnappyVersion* = 5 + # The maximum message size is normally limited by the 24-bit length field in + # the message header but in the case of snappy, we need to protect against + # decompression bombs: + # https://eips.ethereum.org/EIPS/eip-706#avoiding-dos-attacks + maxMsgSize = 1024 * 1024 * 16 # TODO: chronicles re-export here is added for the error # "undeclared identifier: 'activeChroniclesStream'", when the code using p2p @@ -106,11 +121,6 @@ proc read(rlp: var Rlp; T: type DisconnectionReasonList): T raise newException(RlpTypeMismatch, "Single entry list expected") -const - devp2pVersion* = 4 - maxMsgSize = 1024 * 1024 * 10 - HandshakeTimeout = MessageTimeout - include p2p_tracing when tracingEnabled: @@ -184,6 +194,9 @@ proc messagePrinter[MsgType](msg: pointer): string {.gcsafe.} = proc disconnect*(peer: Peer, reason: DisconnectionReason, notifyOtherPeer = false) {.async: (raises:[]).} +# TODO Rework the disconnect-and-raise flow to not do both raising +# and disconnection - this results in convoluted control flow and redundant +# disconnect calls template raisePeerDisconnected(msg: string, r: DisconnectionReason) = var e = newException(PeerDisconnected, msg) e.reason = r @@ -216,7 +229,7 @@ proc handshakeImpl[T](peer: Peer, # understanding what error occured where. # And also, incoming and outgoing disconnect errors should be seperated, # probably by seperating the actual disconnect call to begin with. - await disconnectAndRaise(peer, HandshakeTimeout, + await disconnectAndRaise(peer, TcpError, "Protocol handshake was not received in time.") except CatchableError as exc: raise newException(P2PInternalError, exc.msg) @@ -228,23 +241,17 @@ proc `==`(lhs, rhs: Dispatcher): bool = lhs.activeProtocols == rhs.activeProtocols proc describeProtocols(d: Dispatcher): string = - result = "" - for protocol in d.activeProtocols: - if result.len != 0: result.add(',') - for c in protocol.name: result.add(c) + d.activeProtocols.mapIt($it.capability).join(",") proc numProtocols(d: Dispatcher): int = d.activeProtocols.len -proc getDispatcher(node: EthereumNode, - otherPeerCapabilities: openArray[Capability]): Dispatcher = - # TODO: sub-optimal solution until progress is made here: - # https://github.com/nim-lang/Nim/issues/7457 - # We should be able to find an existing dispatcher without allocating a new one - - new result - newSeq(result.protocolOffsets, protocolCount()) - result.protocolOffsets.fill Opt.none(uint64) +proc getDispatcher( + node: EthereumNode, otherPeerCapabilities: openArray[Capability] +): Opt[Dispatcher] = + let dispatcher = Dispatcher() + newSeq(dispatcher.protocolOffsets, protocolCount()) + dispatcher.protocolOffsets.fill Opt.none(uint64) var nextUserMsgId = 0x10u64 @@ -252,9 +259,8 @@ proc getDispatcher(node: EthereumNode, let idx = localProtocol.index block findMatchingProtocol: for remoteCapability in otherPeerCapabilities: - if localProtocol.name == remoteCapability.name and - localProtocol.version == remoteCapability.version: - result.protocolOffsets[idx] = Opt.some(nextUserMsgId) + if localProtocol.capability == remoteCapability: + dispatcher.protocolOffsets[idx] = Opt.some(nextUserMsgId) nextUserMsgId += localProtocol.messages.len.uint64 break findMatchingProtocol @@ -262,15 +268,21 @@ proc getDispatcher(node: EthereumNode, for i in 0 ..< src.len: dest[index + i] = src[i] - result.messages = newSeq[MessageInfo](nextUserMsgId) - devp2pInfo.messages.copyTo(result.messages, 0) + dispatcher.messages = newSeq[MessageInfo](nextUserMsgId) + devp2pInfo.messages.copyTo(dispatcher.messages, 0) for localProtocol in node.protocols: let idx = localProtocol.index - if result.protocolOffsets[idx].isSome: - result.activeProtocols.add localProtocol - localProtocol.messages.copyTo(result.messages, - result.protocolOffsets[idx].value.int) + if dispatcher.protocolOffsets[idx].isSome: + dispatcher.activeProtocols.add localProtocol + localProtocol.messages.copyTo( + dispatcher.messages, dispatcher.protocolOffsets[idx].value.int + ) + + if dispatcher.numProtocols == 0: + Opt.none(Dispatcher) + else: + Opt.some(dispatcher) proc getMsgName*(peer: Peer, msgId: uint64): string = if not peer.dispatcher.isNil and @@ -279,40 +291,26 @@ proc getMsgName*(peer: Peer, msgId: uint64): string = return peer.dispatcher.messages[msgId].name else: return case msgId - of 0: "hello" - of 1: "disconnect" - of 2: "ping" - of 3: "pong" + of msgIdHello: "hello" + of msgIdDisconnect: "disconnect" + of msgIdPing: "ping" + of msgIdPong: "pong" else: $msgId -proc getMsgMetadata*(peer: Peer, msgId: uint64): (ProtocolInfo, MessageInfo) = - doAssert msgId >= 0 - - let dpInfo = devp2pInfo() - if msgId <= dpInfo.messages[^1].id: - return (dpInfo, dpInfo.messages[msgId]) - - if msgId < peer.dispatcher.messages.len.uint64: - let numProtocol = protocolCount() - for i in 0 ..< numProtocol: - let protocol = getProtocol(i) - let offset = peer.dispatcher.protocolOffsets[i] - if offset.isSome and - offset.value + protocol.messages[^1].id >= msgId: - return (protocol, peer.dispatcher.messages[msgId]) - # Protocol info objects # -proc initProtocol(name: string, version: uint64, - peerInit: PeerStateInitializer, - networkInit: NetworkStateInitializer): ProtocolInfo = +proc initProtocol( + name: string, + version: uint64, + peerInit: PeerStateInitializer, + networkInit: NetworkStateInitializer, +): ProtocolInfo = ProtocolInfo( - name : name, - version : version, + capability: Capability(name: name, version: version), messages: @[], peerStateInitializer: peerInit, - networkStateInitializer: networkInit + networkStateInitializer: networkInit, ) proc setEventHandlers(p: ProtocolInfo, @@ -321,12 +319,13 @@ proc setEventHandlers(p: ProtocolInfo, p.handshake = handshake p.disconnectHandler = disconnectHandler -func asCapability*(p: ProtocolInfo): Capability = - result.name = p.name - result.version = p.version - proc cmp*(lhs, rhs: ProtocolInfo): int = - return cmp(lhs.name, rhs.name) + let c = cmp(lhs.capability.name, rhs.capability.name) + if c == 0: + # Highest version first! + -cmp(lhs.capability.version, rhs.capability.version) + else: + c proc nextMsgResolver[MsgType](msgData: Rlp, future: FutureBase) {.gcsafe, raises: [RlpError].} = @@ -393,29 +392,57 @@ template compressMsg(peer: Peer, data: seq[byte]): seq[byte] = when useSnappy: if peer.snappyEnabled: snappy.encode(data) - else: data + else: + data else: data -proc sendMsg*(peer: Peer, data: seq[byte]) {.async.} = - var cipherText = encryptMsg(peer.compressMsg(data), peer.secretsState) +proc recvMsg( + peer: Peer +): Future[tuple[msgId: uint64, msgRlp: Rlp]] {. + async: (raises: [CancelledError, PeerDisconnected]) +.} = try: - var res = await peer.transport.write(cipherText) - if res != len(cipherText): - # This is ECONNRESET or EPIPE case when remote peer disconnected. - await peer.disconnect(TcpError) - discard - except CatchableError as e: - await peer.disconnect(TcpError) - raise e + var msgBody = await peer.transport.recvMsg() + when useSnappy: + if peer.snappyEnabled: + msgBody = snappy.decode(msgBody, maxMsgSize) + if msgBody.len == 0: + await peer.disconnectAndRaise( + BreachOfProtocol, "Snappy uncompress encountered malformed data" + ) + var tmp = rlpFromBytes(msgBody) + let msgId = tmp.read(uint64) + return (msgId, tmp) + except TransportError as exc: + await peer.disconnectAndRaise(TcpError, exc.msg) + except RlpxTransportError as exc: + await peer.disconnectAndRaise(BreachOfProtocol, exc.msg) + except RlpError: + await peer.disconnectAndRaise(BreachOfProtocol, "Could not decode msgId") -proc send*[Msg](peer: Peer, msg: Msg): Future[void] = +proc encodeMsg(msgId: uint64, msg: auto): seq[byte] = + var rlpWriter = initRlpWriter() + rlpWriter.append msgId + rlpWriter.appendRecordType(msg, typeof(msg).rlpFieldsCount > 1) + rlpWriter.finish + +proc sendMsg( + peer: Peer, data: seq[byte] +): Future[void] {.async: (raises: [CancelledError, PeerDisconnected]).} = + try: + await peer.transport.sendMsg(peer.compressMsg(data)) + except TransportError as exc: + await peer.disconnectAndRaise(TcpError, exc.msg) + except RlpxTransportError as exc: + await peer.disconnectAndRaise(BreachOfProtocol, exc.msg) + +proc send*[Msg]( + peer: Peer, msg: Msg +): Future[void] {.async: (raises: [CancelledError, PeerDisconnected], raw: true).} = logSentMsg(peer, msg) - var rlpWriter = initRlpWriter() - rlpWriter.append perPeerMsgId(peer, Msg) - rlpWriter.appendRecordType(msg, Msg.rlpFieldsCount > 1) - peer.sendMsg rlpWriter.finish + peer.sendMsg encodeMsg(perPeerMsgId(peer, Msg), msg) proc registerRequest(peer: Peer, timeout: Duration, @@ -540,70 +567,6 @@ proc resolveResponseFuture(peer: Peer, msgId: uint64, msg: pointer, reqId: uint6 trace "late or dup RPLx reply ignored" -proc recvMsg*(peer: Peer): Future[tuple[msgId: uint64, msgData: Rlp]] {.async.} = - ## This procs awaits the next complete RLPx message in the TCP stream - - var headerBytes: array[32, byte] - await peer.transport.readExactly(addr headerBytes[0], 32) - - var msgHeader: RlpxHeader - let msgSize = decryptHeader( - peer.secretsState, headerBytes, msgHeader).valueOr: - await peer.disconnectAndRaise(BreachOfProtocol, - "Cannot decrypt RLPx frame header") - 0 # TODO raises analysis insufficient - - if msgSize > maxMsgSize: - await peer.disconnectAndRaise(BreachOfProtocol, - "RLPx message exceeds maximum size") - - let remainingBytes = encryptedLength(msgSize) - 32 - var encryptedBytes = newSeq[byte](remainingBytes) - await peer.transport.readExactly(addr encryptedBytes[0], len(encryptedBytes)) - - let decryptedMaxLength = decryptedLength(msgSize) - var - decryptedBytes = newSeq[byte](decryptedMaxLength) - - if decryptBody(peer.secretsState, encryptedBytes, msgSize, - decryptedBytes).isErr(): - await peer.disconnectAndRaise(BreachOfProtocol, - "Cannot decrypt RLPx frame body") - - decryptedBytes.setLen(msgSize) - - when useSnappy: - if peer.snappyEnabled: - decryptedBytes = snappy.decode(decryptedBytes, maxMsgSize) - if decryptedBytes.len == 0: - await peer.disconnectAndRaise(BreachOfProtocol, - "Snappy uncompress encountered malformed data") - - # Check embedded header-data for start of an obsoleted chunked message. - # Note that the check should come *before* the `msgId` is read. For - # instance, if this is a malformed packet, then the `msgId` might be - # random which in turn might try to access a `peer.dispatcher.messages[]` - # slot with a `nil` entry. - # - # The current RLPx requirements need both tuuple entries be zero, see - # github.com/ethereum/devp2p/blob/master/rlpx.md#framing - # - if (msgHeader[4] and 127) != 0 or # capability-id, now required to be zero - (msgHeader[5] and 127) != 0: # context-id, now required to be zero - await peer.disconnectAndRaise( - BreachOfProtocol, "Rejected obsoleted chunked message header") - - var rlp = rlpFromBytes(decryptedBytes) - - var msgId: uint32 - try: - # uint32 as this seems more than big enough for the amount of msgIds - msgId = rlp.read(uint32) - result = (msgId.uint64, rlp) - except RlpError: - await peer.disconnectAndRaise(BreachOfProtocol, - "Cannot read RLPx message id") - proc checkedRlpRead(peer: Peer, r: var Rlp, MsgType: type): auto {.raises: [RlpError].} = @@ -622,32 +585,6 @@ proc checkedRlpRead(peer: Peer, r: var Rlp, MsgType: type): raise e -proc waitSingleMsg(peer: Peer, MsgType: type): Future[MsgType] {.async.} = - let wantedId = peer.perPeerMsgId(MsgType) - while true: - var (nextMsgId, nextMsgData) = await peer.recvMsg() - - if nextMsgId == wantedId: - try: - result = checkedRlpRead(peer, nextMsgData, MsgType) - logReceivedMsg(peer, result) - return - except rlp.RlpError: - await peer.disconnectAndRaise(BreachOfProtocol, - "Invalid RLPx message body") - - elif nextMsgId == 1: # p2p.disconnect - # TODO: can still raise RlpError here...? - let reasonList = nextMsgData.read(DisconnectionReasonList) - let reason = reasonList.value - await peer.disconnect(reason) - trace "disconnect message received in waitSingleMsg", reason, peer - raisePeerDisconnected("Unexpected disconnect", reason) - else: - debug "Dropped RLPX message", - msg = peer.dispatcher.messages[nextMsgId].name - # TODO: This is breach of protocol? - proc nextMsg*(peer: Peer, MsgType: type): Future[MsgType] = ## This procs awaits a specific RLPx message. ## Any messages received while waiting will be dispatched to their @@ -959,17 +896,24 @@ proc p2pProtocolBackendImpl*(protocol: P2PProtocol): Backend = newLit(protocol.version), protocol.peerInit, protocol.netInit) - -p2pProtocol DevP2P(version = 5, rlpxName = "p2p"): - proc hello(peer: Peer, - version: uint64, - clientId: string, - capabilities: seq[Capability], - listenPort: uint, - nodeId: array[RawPublicKeySize, byte]) +# TODO change to version 5 when snappy is enabled +p2pProtocol DevP2P(version = 4, rlpxName = "p2p"): + proc hello( + peer: Peer, + version: uint64, + clientId: string, + capabilities: seq[Capability], + listenPort: uint, + nodeId: array[RawPublicKeySize, byte], + ) = + # The first hello message gets processed during the initial handshake - this + # version is used for any subsequent messages + await peer.disconnect(BreachOfProtocol, true) proc sendDisconnectMsg(peer: Peer, reason: DisconnectionReasonList) = - trace "disconnect message received", reason=reason.value, peer + ## Notify other peer that we're about to disconnect them for the given + ## reason + trace "disconnect message received", reason = reason.value, peer await peer.disconnect(reason.value, false) # Adding an empty RLP list as the spec defines. @@ -1046,19 +990,13 @@ proc disconnect*(peer: Peer, reason: DisconnectionReason, peer.connectionState = Disconnected removePeer(peer.network, peer) -func validatePubKeyInHello(msg: DevP2P.hello, pubKey: PublicKey): bool = - let pk = PublicKey.fromRaw(msg.nodeId) - pk.isOk and pk[] == pubKey - -func checkUselessPeer(peer: Peer) {.raises: [UselessPeerError].} = - if peer.dispatcher.numProtocols == 0: - # XXX: Send disconnect + UselessPeer - raise newException(UselessPeerError, "Useless peer") - -proc initPeerState*(peer: Peer, capabilities: openArray[Capability]) - {.raises: [UselessPeerError].} = - peer.dispatcher = getDispatcher(peer.network, capabilities) - checkUselessPeer(peer) +proc initPeerState*( + peer: Peer, capabilities: openArray[Capability] +) {.raises: [UselessPeerError].} = + peer.dispatcher = getDispatcher(peer.network, capabilities).valueOr: + raise (ref UselessPeerError)( + msg: "No capabilities in common (" & capabilities.mapIt($it).join(",") + ) # The dispatcher has determined our message ID sequence. # For each message ID, we allocate a potential slot for @@ -1075,6 +1013,7 @@ proc initPeerState*(peer: Peer, capabilities: openArray[Capability]) peer.initProtocolStates peer.dispatcher.activeProtocols proc postHelloSteps(peer: Peer, h: DevP2P.hello) {.async.} = + peer.clientId = h.clientId initPeerState(peer, h.capabilities) # Please note that the ordering of operations here is important! @@ -1122,11 +1061,6 @@ proc postHelloSteps(peer: Peer, h: DevP2P.hello) {.async.} = "messageProcessingLoop ended while connecting") peer.connectionState = Connected -template `^`(arr): auto = - # passes a stack array with a matching `arrLen` - # variable as an open array - arr.toOpenArray(0, `arr Len` - 1) - template setSnappySupport(peer: Peer, node: EthereumNode, hello: DevP2P.hello) = when useSnappy: peer.snappyEnabled = node.protocolVersion >= devp2pSnappyVersion.uint64 and @@ -1151,277 +1085,225 @@ type PeerDisconnectedError, TooManyPeersError -proc initiatorHandshake( - node: EthereumNode, transport: StreamTransport, pubkey: PublicKey -): Future[ConnectionSecret] {. - async: (raises: [CancelledError, TransportError, EthP2PError]) -.} = - # https://github.com/ethereum/devp2p/blob/5713591d0366da78a913a811c7502d9ca91d29a8/rlpx.md#initial-handshake - var - handshake = Handshake.init(node.rng[], node.keys, {Initiator}) - authMsg: array[AuthMessageMaxEIP8, byte] +proc helloHandshake( + node: EthereumNode, peer: Peer +): Future[DevP2P.hello] {.async: (raises: [CancelledError, PeerDisconnected]).} = + ## Negotiate common capabilities using the p2p `hello` message + + # https://github.com/ethereum/devp2p/blob/5713591d0366da78a913a811c7502d9ca91d29a8/rlpx.md#hello-0x00 + + await peer.send( + DevP2P.hello( + version: node.baseProtocolVersion(), + clientId: node.clientId, + capabilities: node.capabilities, + listenPort: 0, # obsolete + nodeId: node.keys.pubkey.toRaw(), + ) + ) + + # The first message received must be a hello or a disconnect + var (msgId, msgData) = await peer.recvMsg() + + try: + case msgId + of msgIdHello: + # Implementations must ignore any additional list elements in Hello + # because they may be used by a future version. + let response = msgData.read(DevP2P.hello) + trace "Received Hello", version = response.version, id = response.clientId + + if response.nodeId != peer.transport.pubkey.toRaw: + await peer.disconnectAndRaise( + BreachOfProtocol, "nodeId in hello does not match RLPx transport identity" + ) + + return response + of msgIdDisconnect: # Disconnection requested by peer + # TODO distinguish their reason from ours + let reason = msgData.read(DisconnectionReasonList).value + await peer.disconnectAndRaise( + reason, "Peer disconnecting during hello: " & $reason + ) + else: + # No other messages may be sent until a Hello is received. + await peer.disconnectAndRaise(BreachOfProtocol, "Expected hello, got " & $msgId) + except RlpError: + await peer.disconnectAndRaise(BreachOfProtocol, "Could not decode hello RLP") + +proc rlpxConnect*( + node: EthereumNode, remote: Node +): Future[Result[Peer, RlpxError]] {.async: (raises: [CancelledError]).} = + # TODO move logging elsewhere - the aim is to have exactly _one_ debug log per + # connection attempt (success or failure) to not spam the logs + initTracing(devp2pInfo, node.protocols) + logScope: + remote + trace "Connecting to peer" let - authMsgLen = handshake.authMessage(node.rng[], pubkey, authMsg).expect( - "No errors with correctly sized buffer" - ) + peer = Peer(remote: remote, network: node) + deadline = sleepAsync(connectionTimeout) - writeRes = await transport.write(addr authMsg[0], authMsgLen) - if writeRes != authMsgLen: - raisePeerDisconnected("Unexpected disconnect while authenticating", TcpError) - - var ackMsg = newSeqOfCap[byte](1024) - ackMsg.setLen(MsgLenLenEIP8) - await transport.readExactly(addr ackMsg[0], len(ackMsg)) - - let ackMsgLen = handshake.decodeMsgLen(ackMsg).valueOr: - raise (ref MalformedMessageError)( - msg: "Could not decode handshake ack length: " & $error - ) - - ackMsg.setLen(ackMsgLen) - await transport.readExactly(addr ackMsg[MsgLenLenEIP8], ackMsgLen - MsgLenLenEIP8) - - handshake.decodeAckMessage(ackMsg).isOkOr: - raise (ref MalformedMessageError)(msg: "Could not decode handshake ack: " & $error) - - handshake.getSecrets(^authMsg, ackMsg) - -proc responderHandshake( - node: EthereumNode, transport: StreamTransport -): Future[(ConnectionSecret, PublicKey)] {. - async: (raises: [CancelledError, TransportError, EthP2PError]) -.} = - # https://github.com/ethereum/devp2p/blob/5713591d0366da78a913a811c7502d9ca91d29a8/rlpx.md#initial-handshake - var - handshake = Handshake.init(node.rng[], node.keys, {auth.Responder}) - authMsg = newSeqOfCap[byte](1024) - - authMsg.setLen(MsgLenLenEIP8) - await transport.readExactly(addr authMsg[0], len(authMsg)) - - let authMsgLen = handshake.decodeMsgLen(authMsg).valueOr: - raise (ref MalformedMessageError)( - msg: "Could not decode handshake auth length: " & $error - ) - - authMsg.setLen(authMsgLen) - await transport.readExactly(addr authMsg[MsgLenLenEIP8], authMsgLen - MsgLenLenEIP8) - - handshake.decodeAuthMessage(authMsg).isOkOr: - raise (ref MalformedMessageError)( - msg: "Could not decode handshake auth message: " & $error - ) - - var ackMsg: array[AckMessageMaxEIP8, byte] - let ackMsgLen = handshake.ackMessage(node.rng[], ackMsg).expect( - "no errors with correcly sized buffer" - ) - - var res = await transport.write(addr ackMsg[0], ackMsgLen) - if res != ackMsgLen: - raisePeerDisconnected("Unexpected disconnect while authenticating", TcpError) - - (handshake.getSecrets(authMsg, ^ackMsg), handshake.remoteHPubkey) - -proc rlpxConnect*(node: EthereumNode, remote: Node): - Future[Result[Peer, RlpxError]] {.async.} = - # TODO: Should we not set some timeouts on the `connect` and `readExactly`s? - # Or should we have a general timeout on the whole rlpxConnect where it gets - # called? - # Now, some parts could potential hang until a tcp timeout is hit? - initTracing(devp2pInfo, node.protocols) - - let peer = Peer(remote: remote, network: node) - let ta = initTAddress(remote.node.address.ip, remote.node.address.tcpPort) var error = true defer: + deadline.cancelSoon() # Harmless if finished + if error: # TODO: Not sure if I like this much - if not isNil(peer.transport): - if not peer.transport.closed: - peer.transport.close() + if peer.transport != nil: + peer.transport.close() peer.transport = try: - await connect(ta) - except TransportError: + let ta = initTAddress(remote.node.address.ip, remote.node.address.tcpPort) + await RlpxTransport.connect(node.rng, node.keys, ta, remote.node.pubkey).wait( + deadline + ) + except AsyncTimeoutError: + debug "Connect timeout" return err(TransportConnectError) - except CatchableError as e: - # Aside from TransportOsError, seems raw CatchableError can also occur? - trace "TCP connect with peer failed", err = $e.name, errMsg = $e.msg + except RlpxTransportError as exc: + debug "Connect RlpxTransport error", err = exc.msg + return err(ProtocolError) + except TransportError as exc: + debug "Connect transport error", err = exc.msg return err(TransportConnectError) - try: - let secrets = await node.initiatorHandshake(peer.transport, remote.node.pubkey) - initSecretState(secrets, peer.secretsState) - except TransportError: - return err(RlpxHandshakeTransportError) - except EthP2PError: - return err(RlpxHandshakeError) - except CatchableError as e: - raiseAssert($e.name & " " & $e.msg) - logConnectedPeer peer # RLPx p2p capability handshake: After the initial handshake, both sides of # the connection must send either Hello or a Disconnect message. - let - sendHelloFut = peer.hello( - node.baseProtocolVersion(), - node.clientId, - node.capabilities, - uint(node.address.tcpPort), - node.keys.pubkey.toRaw()) - - receiveHelloFut = peer.waitSingleMsg(DevP2P.hello) - - response = - try: - await peer.handshakeImpl( - sendHelloFut, - receiveHelloFut, - 10.seconds) - except RlpError: - return err(ProtocolError) - except PeerDisconnected: + let response = + try: + await node.helloHandshake(peer).wait(deadline) + except AsyncTimeoutError: + debug "Connect handshake timeout" + return err(P2PHandshakeError) + except PeerDisconnected as exc: + debug "Connect handshake disconneced", err = exc.msg, reason = exc.reason + case exc.reason + of TooManyPeers: + return err(TooManyPeersError) + else: return err(PeerDisconnectedError) - # TODO: Strange compiler error - # case e.reason: - # of HandshakeTimeout: - # # Yeah, a bit odd but in this case PeerDisconnected comes from a - # # timeout on the P2P Hello message. TODO: Clean-up that handshakeImpl - # return err(P2PHandshakeError) - # of TooManyPeers: - # return err(TooManyPeersError) - # else: - # return err(PeerDisconnectedError) - except TransportError: - return err(P2PTransportError) - except P2PInternalError: - return err(P2PHandshakeError) - except CatchableError as e: - raiseAssert($e.name & " " & $e.msg) - - if not validatePubKeyInHello(response, remote.node.pubkey): - trace "Wrong devp2p identity in Hello message" - return err(InvalidIdentityError) peer.setSnappySupport(node, response) - trace "DevP2P handshake completed", peer = remote, + logScope: clientId = response.clientId + trace "DevP2P handshake completed" + try: await postHelloSteps(peer, response) - except RlpError: - return err(ProtocolError) - except PeerDisconnected as e: - case e.reason: + except PeerDisconnected as exc: + debug "Disconnect finishing hello", + remote, clientId = response.clientId, err = exc.msg, reason = exc.reason + case exc.reason of TooManyPeers: return err(TooManyPeersError) else: return err(PeerDisconnectedError) - except UselessPeerError: + except UselessPeerError as exc: + debug "Useless peer finishing hello", err = exc.msg return err(UselessRlpxPeerError) - except TransportError: - return err(P2PTransportError) - except EthP2PError: + except EthP2PError as exc: + debug "P2P error finishing hello", err = exc.msg return err(ProtocolError) except CatchableError as e: + # TODO certainly needs fixing - this could be a cancellation! raiseAssert($e.name & " " & $e.msg) - debug "Peer fully connected", peer = remote, clientId = response.clientId + debug "Peer connected", capabilities = response.capabilities error = false return ok(peer) # TODO: rework rlpxAccept similar to rlpxConnect. -proc rlpxAccept*( - node: EthereumNode, transport: StreamTransport): Future[Peer] {.async: (raises: []).} = +proc rlpxAccept*(node: EthereumNode, stream: StreamTransport): Future[Peer] {.async.} = + # TODO move logging elsewhere - the aim is to have exactly _one_ debug log per + # connection attempt (success or failure) to not spam the logs initTracing(devp2pInfo, node.protocols) - let peer = Peer(transport: transport, network: node) + let + peer = Peer(network: node) + remoteAddress = stream.remoteAddress() + deadline = sleepAsync(connectionTimeout) + trace "Incoming connection", remoteAddress = $remoteAddress + var ok = false try: - let (secrets, pubkey) = await node.responderHandshake(transport) - initSecretState(secrets, peer.secretsState) + peer.transport = + await RlpxTransport.accept(node.rng, node.keys, stream).wait(deadline) - let listenPort = transport.localAddress().port + let + # The ports in this address are not necessarily the ports that the peer is + # actually listening on, so we cannot use this information to connect to + # the peer in the future! + address = Address( + ip: remoteAddress.address, + tcpPort: remoteAddress.port, + udpPort: remoteAddress.port, + ) + + peer.remote = newNode(ENode(pubkey: peer.transport.pubkey, address: address)) logAcceptedPeer peer + logScope: + remote = peer.remote - var sendHelloFut = peer.hello( - node.baseProtocolVersion(), - node.clientId, - node.capabilities, - listenPort.uint, - node.keys.pubkey.toRaw()) - - var response = await peer.handshakeImpl( - sendHelloFut, - peer.waitSingleMsg(DevP2P.hello), - 10.seconds) - - trace "Received Hello", version=response.version, id=response.clientId - - if not validatePubKeyInHello(response, pubkey): - raise (ref MalformedMessageError)(msg: "Wrong pubkey in hello message") + let response = await node.helloHandshake(peer).wait(deadline) peer.setSnappySupport(node, response) - let remote = transport.remoteAddress() - let address = Address(ip: remote.address, tcpPort: remote.port, - udpPort: remote.port) - peer.remote = newNode(ENode(pubkey: pubkey, address: address)) - - trace "devp2p handshake completed", peer = peer.remote, + logScope: clientId = response.clientId + trace "devp2p handshake completed" + # In case there is an outgoing connection started with this peer we give # precedence to that one and we disconnect here with `AlreadyConnected` if peer.remote in node.peerPool.connectedNodes or peer.remote in node.peerPool.connectingNodes: trace "Duplicate connection in rlpxAccept" - raisePeerDisconnected("Peer already connecting or connected", - AlreadyConnected) + raisePeerDisconnected("Peer already connecting or connected", AlreadyConnected) node.peerPool.connectingNodes.incl(peer.remote) await postHelloSteps(peer, response) ok = true - trace "Peer fully connected", peer = peer.remote, clientId = response.clientId - except PeerDisconnected as e: - case e.reason - of AlreadyConnected, TooManyPeers, MessageTimeout: - trace "RLPx disconnect", reason = e.reason, peer = peer.remote - else: - debug "RLPx disconnect unexpected", reason = e.reason, - msg = e.msg, peer = peer.remote + debug "Peer accepted", capabilities = response.capabilities + except PeerDisconnected as exc: + debug "Disconnect while accepting", + remote = peer.remote, clientId = peer.clientId, reason = exc.reason, err = exc.msg - rlpx_accept_failure.inc(labelValues = [$e.reason]) - except TransportIncompleteError: - trace "Connection dropped in rlpxAccept", remote = peer.remote + rlpx_accept_failure.inc(labelValues = [$exc.reason]) + except TransportIncompleteError as exc: + trace "Connection dropped in rlpxAccept", remote = peer.remote, err = exc.msg rlpx_accept_failure.inc(labelValues = [$TransportIncompleteError]) - except UselessPeerError: - trace "Disconnecting useless peer", peer = peer.remote + except UselessPeerError as exc: + debug "Useless peer while accepting", + remote = peer.remote, clientId = peer.clientId, err = exc.msg rlpx_accept_failure.inc(labelValues = [$UselessPeerError]) - except RlpTypeMismatch as e: - # Some peers report capabilities with names longer than 3 chars. We ignore - # those for now. Maybe we should allow this though. - trace "Rlp error in rlpxAccept", err = e.msg, errName = e.name + except RlpTypeMismatch as exc: + debug "Rlp error while accepting", + remote = peer.remote, clientId = peer.clientId, err = exc.msg rlpx_accept_failure.inc(labelValues = [$RlpTypeMismatch]) - except TransportOsError as e: - if e.code == OSErrorCode(110): - trace "RLPx timeout", err = e.msg, errName = e.name + except TransportOsError as exc: + debug "Transport error while accepting", + remote = peer.remote, clientId = peer.clientId, err = exc.msg + if exc.code == OSErrorCode(110): rlpx_accept_failure.inc(labelValues = ["tcp_timeout"]) else: - trace "TransportOsError", err = e.msg, errName = e.name - rlpx_accept_failure.inc(labelValues = [$e.name]) - except CatchableError as e: - trace "RLPx error", err = e.msg, errName = e.name - rlpx_accept_failure.inc(labelValues = [$e.name]) + rlpx_accept_failure.inc(labelValues = [$exc.name]) + except CatchableError as exc: + debug "Error while accepting", + remote = peer.remote, clientId = peer.clientId, err = exc.msg + rlpx_accept_failure.inc(labelValues = [$exc.name]) + + deadline.cancelSoon() # Harmless if finished if not ok: if not isNil(peer.transport): @@ -1432,23 +1314,3 @@ proc rlpxAccept*( else: rlpx_accept_success.inc() return peer - -when isMainModule: - - when false: - # The assignments below can be used to investigate if the RLPx procs - # are considered GcSafe. The short answer is that they aren't, because - # they dispatch into user code that might use the GC. - type - GcSafeDispatchMsg = proc (peer: Peer, msgId: uint64, msgData: var Rlp) - - GcSafeRecvMsg = proc (peer: Peer): - Future[tuple[msgId: uint64, msgData: Rlp]] {.gcsafe.} - - GcSafeAccept = proc (transport: StreamTransport, myKeys: KeyPair): - Future[Peer] {.gcsafe.} - - var - dispatchMsgPtr = invokeThunk - recvMsgPtr: GcSafeRecvMsg = recvMsg - acceptPtr: GcSafeAccept = rlpxAccept diff --git a/eth/p2p/rlpxcrypt.nim b/eth/p2p/rlpxcrypt.nim index b50049f..6bc50e9 100644 --- a/eth/p2p/rlpxcrypt.nim +++ b/eth/p2p/rlpxcrypt.nim @@ -38,7 +38,8 @@ type IncompleteError = "rlpx: data incomplete" IncorrectArgs = "rlpx: incorrect arguments" - RlpxHeader* = array[16, byte] + RlpxEncryptedHeader* = array[RlpHeaderLength + RlpMacLength, byte] + RlpxHeader* = array[RlpHeaderLength, byte] RlpxResult*[T] = Result[T, RlpxError] @@ -159,21 +160,19 @@ proc encryptMsg*(msg: openArray[byte], secrets: var SecretState): seq[byte] = proc getBodySize*(a: RlpxHeader): int = (int(a[0]) shl 16) or (int(a[1]) shl 8) or int(a[2]) -proc decryptHeader*(c: var SecretState, data: openArray[byte], - output: var RlpxHeader): RlpxResult[int] = +proc decryptHeader*(c: var SecretState, data: openArray[byte]): RlpxResult[RlpxHeader] = ## Decrypts header `data` using SecretState `c` context and store ## result into `output`. ## - ## `header` must be exactly `RlpHeaderLength + RlpMacLength` length. - ## `output` must be at least `RlpHeaderLength` length. + ## `header` must be at least `RlpHeaderLength + RlpMacLength` length. + var tmpmac: keccak256 aes: array[RlpHeaderLength, byte] - if len(data) != RlpHeaderLength + RlpMacLength: + if len(data) < RlpHeaderLength + RlpMacLength: return err(IncompleteError) - if len(output) < RlpHeaderLength: - return err(IncorrectArgs) + # mac_secret = self.ingress_mac.digest()[:HEADER_LEN] tmpmac = c.imac var macsec = tmpmac.finish() @@ -187,14 +186,14 @@ proc decryptHeader*(c: var SecretState, data: openArray[byte], tmpmac = c.imac var expectMac = tmpmac.finish() # if not bytes_eq(expected_header_mac, header_mac): - let headerMacPos = RlpHeaderLength - if not equalMem(cast[pointer](unsafeAddr data[headerMacPos]), - cast[pointer](addr expectMac.data[0]), RlpMacLength): - err(IncorrectMac) - else: - # return self.aes_dec.update(header_ciphertext) - c.aesdec.decrypt(toa(data, 0, RlpHeaderLength), output) - ok(output.getBodySize()) + if not equalMem(unsafeAddr data[RlpHeaderLength], + addr expectMac.data[0], RlpMacLength): + return err(IncorrectMac) + + # return self.aes_dec.update(header_ciphertext) + var output: RlpxHeader + c.aesdec.decrypt(toa(data, 0, RlpHeaderLength), output) + ok(output) proc decryptBody*(c: var SecretState, data: openArray[byte], bodysize: int, output: var openArray[byte]): RlpxResult[void] = diff --git a/eth/p2p/rlpxtransport.nim b/eth/p2p/rlpxtransport.nim new file mode 100644 index 0000000..e42d305 --- /dev/null +++ b/eth/p2p/rlpxtransport.nim @@ -0,0 +1,243 @@ +# nim-eth +# Copyright (c) 2018-2024 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at +# https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at +# https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except +# according to those terms. + +{.push raises: [], gcsafe.} + +import results, chronos, ../common/keys, ./[auth, rlpxcrypt] + +export results, keys + +type + RlpxTransport* = ref object + stream: StreamTransport + state: SecretState + pubkey*: PublicKey + + RlpxTransportError* = object of CatchableError + +template `^`(arr): auto = + # passes a stack array with a matching `arrLen` variable as an open array + arr.toOpenArray(0, `arr Len` - 1) + +proc initiatorHandshake( + rng: ref HmacDrbgContext, + keys: KeyPair, + stream: StreamTransport, + remotePubkey: PublicKey, +): Future[ConnectionSecret] {. + async: (raises: [CancelledError, TransportError, RlpxTransportError]) +.} = + # https://github.com/ethereum/devp2p/blob/5713591d0366da78a913a811c7502d9ca91d29a8/rlpx.md#initial-handshake + var + handshake = Handshake.init(rng[], keys, {Initiator}) + authMsg: array[AuthMessageMaxEIP8, byte] + + let + authMsgLen = handshake.authMessage(rng[], remotePubkey, authMsg).expect( + "No errors with correctly sized buffer" + ) + + writeRes = await stream.write(addr authMsg[0], authMsgLen) + if writeRes != authMsgLen: + raise (ref RlpxTransportError)(msg: "Could not write RLPx handshake header") + + var ackMsg = newSeqOfCap[byte](1024) + ackMsg.setLen(MsgLenLenEIP8) + await stream.readExactly(addr ackMsg[0], len(ackMsg)) + + let ackMsgLen = handshake.decodeAckMsgLen(ackMsg).valueOr: + raise + (ref RlpxTransportError)(msg: "Could not decode handshake ack length: " & $error) + + ackMsg.setLen(ackMsgLen) + await stream.readExactly(addr ackMsg[MsgLenLenEIP8], ackMsgLen - MsgLenLenEIP8) + + handshake.decodeAckMessage(ackMsg).isOkOr: + raise (ref RlpxTransportError)(msg: "Could not decode handshake ack: " & $error) + + handshake.getSecrets(^authMsg, ackMsg) + +proc responderHandshake( + rng: ref HmacDrbgContext, keys: KeyPair, stream: StreamTransport +): Future[(ConnectionSecret, PublicKey)] {. + async: (raises: [CancelledError, TransportError, RlpxTransportError]) +.} = + # https://github.com/ethereum/devp2p/blob/5713591d0366da78a913a811c7502d9ca91d29a8/rlpx.md#initial-handshake + var + handshake = Handshake.init(rng[], keys, {auth.Responder}) + authMsg = newSeqOfCap[byte](1024) + + authMsg.setLen(MsgLenLenEIP8) + await stream.readExactly(addr authMsg[0], len(authMsg)) + + let authMsgLen = handshake.decodeAuthMsgLen(authMsg).valueOr: + raise + (ref RlpxTransportError)(msg: "Could not decode handshake auth length: " & $error) + + authMsg.setLen(authMsgLen) + await stream.readExactly(addr authMsg[MsgLenLenEIP8], authMsgLen - MsgLenLenEIP8) + + handshake.decodeAuthMessage(authMsg).isOkOr: + raise (ref RlpxTransportError)( + msg: "Could not decode handshake auth message: " & $error + ) + + var ackMsg: array[AckMessageMaxEIP8, byte] + let ackMsgLen = + handshake.ackMessage(rng[], ackMsg).expect("no errors with correcly sized buffer") + + var res = await stream.write(addr ackMsg[0], ackMsgLen) + if res != ackMsgLen: + raise (ref RlpxTransportError)(msg: "Could not write RLPx ack message") + + (handshake.getSecrets(authMsg, ^ackMsg), handshake.remoteHPubkey) + +proc connect*( + _: type RlpxTransport, + rng: ref HmacDrbgContext, + keys: KeyPair, + address: TransportAddress, + remotePubkey: PublicKey, +): Future[RlpxTransport] {. + async: (raises: [CancelledError, TransportError, RlpxTransportError]) +.} = + var stream = await connect(address) + + try: + let secrets = await initiatorHandshake(rng, keys, stream, remotePubkey) + var res = RlpxTransport(stream: move(stream), pubkey: remotePubkey) + initSecretState(secrets, res.state) + res + finally: + if stream != nil: + stream.close() + +proc accept*( + _: type RlpxTransport, + rng: ref HmacDrbgContext, + keys: KeyPair, + stream: StreamTransport, +): Future[RlpxTransport] {. + async: (raises: [CancelledError, TransportError, RlpxTransportError]) +.} = + var stream = stream + try: + let (secrets, remotePubkey) = await responderHandshake(rng, keys, stream) + var res = RlpxTransport(stream: move(stream), pubkey: remotePubkey) + initSecretState(secrets, res.state) + res + finally: + if stream != nil: + stream.close() + +proc recvMsg*( + transport: RlpxTransport +): Future[seq[byte]] {. + async: (raises: [CancelledError, TransportError, RlpxTransportError]) +.} = + ## Read an RLPx frame from the given peer + var msgHeaderEnc: RlpxEncryptedHeader + await transport.stream.readExactly(addr msgHeaderEnc[0], msgHeaderEnc.len) + + let msgHeader = decryptHeader(transport.state, msgHeaderEnc).valueOr: + raise (ref RlpxTransportError)(msg: "Cannot decrypt RLPx frame header") + + # The capability-id and context id are always zero + # https://github.com/ethereum/devp2p/blob/5713591d0366da78a913a811c7502d9ca91d29a8/rlpx.md#framing + if (msgHeader[4] != 0x80) or (msgHeader[5] != 0x80): + raise + (ref RlpxTransportError)(msg: "Invalid capability-id/context-id in RLPx header") + + let msgSize = msgHeader.getBodySize() + let remainingBytes = encryptedLength(msgSize) - 32 + + var encryptedBytes = newSeq[byte](remainingBytes) + await transport.stream.readExactly(addr encryptedBytes[0], len(encryptedBytes)) + + let decryptedMaxLength = decryptedLength(msgSize) # Padded length + var msgBody = newSeq[byte](decryptedMaxLength) + + if decryptBody(transport.state, encryptedBytes, msgSize, msgBody).isErr(): + raise (ref RlpxTransportError)(msg: "Cannot decrypt message body") + + reset(encryptedBytes) # Release memory (TODO: in-place decryption) + + msgBody.setLen(msgSize) # Remove padding + + msgBody + +proc sendMsg*( + transport: RlpxTransport, data: seq[byte] +) {.async: (raises: [CancelledError, TransportError, RlpxTransportError]).} = + let cipherText = encryptMsg(data, transport.state) + var res = await transport.stream.write(cipherText) + if res != len(cipherText): + raise (ref RlpxTransportError)(msg: "Could not complete writing message") + +proc remoteAddress*( + transport: RlpxTransport +): TransportAddress {.raises: [TransportOsError].} = + transport.stream.remoteAddress() + +proc closed*(transport: RlpxTransport): bool = + transport.stream != nil and transport.stream.closed + +proc close*(transport: RlpxTransport) = + if transport.stream != nil: + transport.stream.close() + +proc closeWait*( + transport: RlpxTransport +): Future[void] {.async: (raises: [], raw: true).} = + transport.stream.closeWait() + +when isMainModule: + # Simple CLI application for negotiating an RLPx connection with a peer + + import stew/byteutils, std/cmdline, std/strutils, eth/rlp + if paramCount() < 3: + echo "rlpxtransport ip port pubkey" + quit 1 + + let + rng = newRng() + kp = KeyPair.random(rng[]) + + echo "Local key: ", toHex(kp.pubkey.toRaw()) + + let client = waitFor RlpxTransport.connect( + rng, + kp, + initTAddress(paramStr(1), parseInt(paramStr(2))), + PublicKey.fromHex(paramStr(3))[], + ) + + proc encodeMsg(msgId: uint64, msg: auto): seq[byte] = + var rlpWriter = initRlpWriter() + rlpWriter.append msgId + rlpWriter.appendRecordType(msg, typeof(msg).rlpFieldsCount > 1) + rlpWriter.finish + + waitFor client.sendMsg( + encodeMsg( + uint64 0, (uint64 4, "nimbus", @[("eth", uint64 68)], uint64 0, kp.pubkey.toRaw()) + ) + ) + + while true: + echo "Reading message" + var data = waitFor client.recvMsg() + var rlp = rlpFromBytes(data) + let msgId = rlp.read(uint64) + if msgId == 0: + echo "Hello: ", + rlp.read((uint64, string, seq[(string, uint64)], uint64, seq[byte])) + else: + echo "Unknown message ", msgId, " ", toHex(data) diff --git a/tests/p2p/all_tests.nim b/tests/p2p/all_tests.nim index 66b2193..f1261ec 100644 --- a/tests/p2p/all_tests.nim +++ b/tests/p2p/all_tests.nim @@ -6,4 +6,5 @@ import ./test_ecies, ./test_enode, ./test_rlpx_thunk, + ./test_rlpxtransport, ./test_protocol_handlers \ No newline at end of file diff --git a/tests/p2p/test_crypt.nim b/tests/p2p/test_crypt.nim index 6ec7b1b..46382eb 100644 --- a/tests/p2p/test_crypt.nim +++ b/tests/p2p/test_crypt.nim @@ -159,11 +159,9 @@ suite "Ethereum RLPx encryption/decryption test suite": var csecResponder = responder.getSecrets(m0, m1) var stateInitiator: SecretState var stateResponder: SecretState - var iheader, rheader: array[16, byte] + var iheader: array[16, byte] initSecretState(csecInitiator, stateInitiator) initSecretState(csecResponder, stateResponder) - burnMem(iheader) - burnMem(rheader) for i in 1..1000: # initiator -> responder block: @@ -176,8 +174,9 @@ suite "Ethereum RLPx encryption/decryption test suite": randomBytes(ibody) == len(ibody) stateInitiator.encrypt(iheader, ibody, encrypted).isOk() - stateResponder.decryptHeader(toOpenArray(encrypted, 0, 31), - rheader).isOk() + let rheader = stateResponder.decryptHeader( + toOpenArray(encrypted, 0, 31)).expect("valid data") + var length = getBodySize(rheader) check length == len(ibody) var rbody = newSeq[byte](decryptedLength(length)) @@ -190,7 +189,6 @@ suite "Ethereum RLPx encryption/decryption test suite": iheader == rheader ibody == rbody burnMem(iheader) - burnMem(rheader) # responder -> initiator block: var ibody = newSeq[byte](i * 3) @@ -202,8 +200,8 @@ suite "Ethereum RLPx encryption/decryption test suite": randomBytes(ibody) == len(ibody) stateResponder.encrypt(iheader, ibody, encrypted).isOk() - stateInitiator.decryptHeader(toOpenArray(encrypted, 0, 31), - rheader).isOk() + let rheader = stateInitiator.decryptHeader( + toOpenArray(encrypted, 0, 31)).expect("valid data") var length = getBodySize(rheader) check length == len(ibody) var rbody = newSeq[byte](decryptedLength(length)) @@ -216,4 +214,3 @@ suite "Ethereum RLPx encryption/decryption test suite": iheader == rheader ibody == rbody burnMem(iheader) - burnMem(rheader) diff --git a/tests/p2p/test_rlpxtransport.nim b/tests/p2p/test_rlpxtransport.nim new file mode 100644 index 0000000..733167a --- /dev/null +++ b/tests/p2p/test_rlpxtransport.nim @@ -0,0 +1,60 @@ +{.used.} + +import + unittest2, + chronos/unittest2/asynctests, + ../../eth/common/keys, + ../../eth/p2p/rlpxtransport + +suite "RLPx transport": + setup: + let + rng = newRng() + keys1 = KeyPair.random(rng[]) + keys2 = KeyPair.random(rng[]) + server = createStreamServer(initTAddress("127.0.0.1:0"), {ReuseAddr}) + + teardown: + waitFor server.closeWait() + + asyncTest "Connect/accept": + const msg = @[byte 0, 1, 2, 3] + proc serveClient(server: StreamServer) {.async.} = + let transp = await server.accept() + let a = await RlpxTransport.accept(rng, keys1, transp) + await a.sendMsg(msg) + await a.closeWait() + + let serverFut = server.serveClient() + defer: + await serverFut.wait(1.seconds) + + let client = + await RlpxTransport.connect(rng, keys2, server.localAddress(), keys1.pubkey) + + defer: + await client.closeWait() + let rmsg = await client.recvMsg().wait(1.seconds) + + check: + msg == rmsg + + await serverFut + + asyncTest "Detect invalid pubkey": + proc serveClient(server: StreamServer) {.async.} = + let transp = await server.accept() + discard await RlpxTransport.accept(rng, keys1, transp) + raiseAssert "should fail to accept due to pubkey error" + + let serverFut = server.serveClient() + defer: + expect(RlpxTransportError): + await serverFut.wait(1.seconds) + + let keys3 = KeyPair.random(rng[]) + + # accept side should close connections + expect(TransportError): + discard + await RlpxTransport.connect(rng, keys2, server.localAddress(), keys3.pubkey) diff --git a/tests/rlp/test_api_usage.nim b/tests/rlp/test_api_usage.nim index 80fb998..cf9ad2f 100644 --- a/tests/rlp/test_api_usage.nim +++ b/tests/rlp/test_api_usage.nim @@ -21,8 +21,7 @@ proc test_blockBodyTranscode() = transactions: @[ Transaction(nonce: 1)]), BlockBody( - uncles: @[ - BlockHeader(nonce: BlockNonce([0x20u8,0,0,0,0,0,0,0]))]), + uncles: @[Header(nonce: Bytes8([0x20u8,0,0,0,0,0,0,0]))]), BlockBody(), BlockBody( transactions: @[