diff --git a/eth/p2p/discoveryv5/encoding.nim b/eth/p2p/discoveryv5/encoding.nim index d360038..8c2577f 100644 --- a/eth/p2p/discoveryv5/encoding.nim +++ b/eth/p2p/discoveryv5/encoding.nim @@ -158,31 +158,42 @@ proc decryptGCM(key: array[16, byte], nonce, ct, authData: openarray[byte]): seq result = @[] dctx.clear() -proc decodePacketBody(typ: byte, body: openarray[byte], res: var Packet): bool = - if typ >= PacketKind.low.byte and typ <= PacketKind.high.byte: - let kind = cast[PacketKind](typ) - res = Packet(kind: kind) - var rlp = rlpFromBytes(@body.toRange) - rlp.enterList() +type + DecodePacketResult = enum + decodingSuccessful + invalidPacketPayload + invalidPacketType + unsupportedPacketType + +proc decodePacketBody(typ: byte, + body: openarray[byte], + res: var Packet): DecodePacketResult = + if typ < PacketKind.low.byte or typ > PacketKind.high.byte: + return invalidPacketType + + let kind = cast[PacketKind](typ) + res = Packet(kind: kind) + var rlp = rlpFromBytes(@body.toRange) + if rlp.enterList: res.reqId = rlp.read(RequestId) proc decode[T](rlp: var Rlp, v: var T) {.inline, nimcall.} = for k, v in v.fieldPairs: v = rlp.read(typeof(v)) - template decode(k: untyped) = - if k == kind: - decode(rlp, res.k) - result = true + case kind + of unused: return invalidPacketPayload + of ping: rlp.decode(res.ping) + of pong: rlp.decode(res.pong) + of findNode: rlp.decode(res.findNode) + of nodes: rlp.decode(res.nodes) + of regtopic, ticket, regconfirmation, topicquery: + # TODO Implement these packet types + return unsupportedPacketType - decode(ping) - decode(pong) - decode(findNode) - decode(nodes) + return decodingSuccessful else: - debug "unknown packet type ", typ - - return true + return invalidPacketPayload proc decodeAuthResp(c: Codec, fromId: NodeId, head: AuthHeader, challenge: Whoareyou, secrets: var HandshakeSecrets, newNode: var Node): bool = @@ -215,6 +226,8 @@ proc decodeEncrypted*(c: var Codec, var r = rlpFromBytes(input[32 .. ^1]) var auth: AuthHeader var readKey: array[16, byte] + logScope: sender = $fromAddr + if r.isList: # Handshake - rlp list indicates auth-header @@ -258,7 +271,12 @@ proc decodeEncrypted*(c: var Codec, let body = decryptGCM(readKey, auth.auth, bodyEnc.toOpenArray, input[0 .. 31].toOpenArray) if body.len > 1: - result = decodePacketBody(body[0], body.toOpenArray(1, body.high), packet) + let status = decodePacketBody(body[0], body.toOpenArray(1, body.high), packet) + if status == decodingSuccessful: + return true + else: + debug "Failed to decode discovery packet", reason = status + return false proc newRequestId*(): RequestId = if randomBytes(addr result, sizeof(result)) != sizeof(result): diff --git a/eth/p2p/discoveryv5/enr.nim b/eth/p2p/discoveryv5/enr.nim index 6fc7b8a..0f3c5a5 100644 --- a/eth/p2p/discoveryv5/enr.nim +++ b/eth/p2p/discoveryv5/enr.nim @@ -195,7 +195,8 @@ proc verifySignatureV4(r: Record, sigData: openarray[byte], content: seq[byte]): proc verifySignature(r: Record): bool = var rlp = rlpFromBytes(r.raw.toRange) let sz = rlp.listLen - rlp.enterList() + if not rlp.enterList: + return false let sigData = rlp.read(Bytes) let content = block: var writer = initRlpList(sz - 1) @@ -219,12 +220,16 @@ proc fromBytesAux(r: var Record): bool = return false var rlp = rlpFromBytes(r.raw.toRange) + if not rlp.isList: + return false + let sz = rlp.listLen if sz < minRlpListLen or sz mod 2 != 0: # Wrong rlp object return false - rlp.enterList() + # We already know we are working with a list + discard rlp.enterList() rlp.skipElem() # Skip signature r.seqNum = rlp.read(uint64) diff --git a/eth/p2p/rlpx.nim b/eth/p2p/rlpx.nim index 5d842dc..ec8d405 100644 --- a/eth/p2p/rlpx.nim +++ b/eth/p2p/rlpx.nim @@ -567,7 +567,7 @@ proc p2pProtocolBackendImpl*(protocol: P2PProtocol): Backend = read = bindSym("read", brForceOpen) checkedRlpRead = bindSym "checkedRlpRead" startList = bindSym "startList" - enterList = bindSym "enterList" + tryEnterList = bindSym "tryEnterList" finish = bindSym "finish" messagePrinter = bindSym "messagePrinter" @@ -677,7 +677,7 @@ proc p2pProtocolBackendImpl*(protocol: P2PProtocol): Backend = let paramCount = paramsToWrite.len - readParamsPrelude = if paramCount > 1: newCall(enterList, receivedRlp) + readParamsPrelude = if paramCount > 1: newCall(tryEnterList, receivedRlp) else: newStmtList() when tracingEnabled: diff --git a/eth/p2p/rlpx_protocols/les/flow_control.nim b/eth/p2p/rlpx_protocols/les/flow_control.nim index 92ad5cf..ec0c3a3 100644 --- a/eth/p2p/rlpx_protocols/les/flow_control.nim +++ b/eth/p2p/rlpx_protocols/les/flow_control.nim @@ -104,16 +104,18 @@ proc loadMessageStats*(network: LesNetwork, try: var statsRlp = rlpFromBytes(stats.toRange) - statsRlp.enterList + if not statsRlp.enterList: + notice "Found a corrupted LES stats record" + break readFromDB let version = statsRlp.read(int) if version != lesStatsVer: - notice "Found outdated LES stats record" + notice "Found an outdated LES stats record" break readFromDB statsRlp >> network.messageStats if network.messageStats.len <= les.messages[^1].id: - notice "Found incomplete LES stats record" + notice "Found an incomplete LES stats record" break readFromDB return true diff --git a/eth/p2p/rlpx_protocols/waku_protocol.nim b/eth/p2p/rlpx_protocols/waku_protocol.nim index e2aa669..4169619 100644 --- a/eth/p2p/rlpx_protocols/waku_protocol.nim +++ b/eth/p2p/rlpx_protocols/waku_protocol.nim @@ -147,10 +147,18 @@ proc append*(rlpWriter: var RlpWriter, value: StatusOptions) = rlpWriter.append(rlpFromBytes(bytes.toRange)) proc read*(rlp: var Rlp, T: typedesc[StatusOptions]): T = + if not rlp.isList(): + raise newException(RlpTypeMismatch, + "List expected, but the source RLP is not a list.") + let sz = rlp.listLen() - rlp.enterList() + # We already know that we are working with a list + discard rlp.enterList() for i in 0 ..< sz: - rlp.enterList() + if not rlp.enterList(): + raise newException(RlpTypeMismatch, + "List expected, but the source RLP is not a list.") + var k: KeyKind try: k = rlp.read(KeyKind) diff --git a/eth/rlp.nim b/eth/rlp.nim index 627fa7a..13a524d 100644 --- a/eth/rlp.nim +++ b/eth/rlp.nim @@ -242,10 +242,16 @@ proc currentElemEnd*(self: Rlp): int = elif isBlob() or isList(): result += payloadOffset() + payloadBytesCount() -proc enterList*(self: var Rlp) = +proc enterList*(self: var Rlp): bool = if not isList(): - raise newException(RlpTypeMismatch, "List expected, but source RLP is not a list") + return false + position += payloadOffset() + return true + +proc tryEnterList*(self: var Rlp) = + if not enterList(): + raise newException(RlpTypeMismatch, "List expected, but source RLP is not a list") proc skipElem*(rlp: var Rlp) = rlp.position = rlp.currentElemEnd diff --git a/eth/trie/hexary.nim b/eth/trie/hexary.nim index 6811171..60e6576 100644 --- a/eth/trie/hexary.nim +++ b/eth/trie/hexary.nim @@ -386,7 +386,8 @@ proc replaceValue(data: Rlp, key: NibblesRange, value: BytesRange): Bytes = # XXX: This can be optimized to a direct bitwise copy of the source RLP var iter = data - iter.enterList() + # We already know that we are working with a list + discard iter.enterList() for i in 0 ..< 16: r.append iter iter.skipElem @@ -511,7 +512,8 @@ proc deleteAt(self: var HexaryTrie; else: var rlpRes = initRlpList(17) var iter = origRlp - iter.enterList + # We already know that we are working with a list + discard iter.enterList for i in 0 ..< 16: rlpRes.append iter iter.skipElem diff --git a/tests/rlp/test_api_usage.nim b/tests/rlp/test_api_usage.nim index 7eadf24..2b33868 100644 --- a/tests/rlp/test_api_usage.nim +++ b/tests/rlp/test_api_usage.nim @@ -102,7 +102,7 @@ test "encode and decode lists": var list = rlpFromBytes encodeList(rlp.listELem(1), rlp.listELem(0)).toRange # test that iteration with enterList/skipElem works as expected - list.enterList + discard list.enterList # We alreay know that we are working with a list check list.toString == "Lorem ipsum dolor sit amet" list.skipElem