diff --git a/.appveyor.yml b/.appveyor.yml index d8504f7..13ad96b 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -77,6 +77,7 @@ build_script: test_script: - nimble test - nimble build_dcli + - nimble build_portalcli deploy: off diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0c71200..a1b6034 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -237,4 +237,4 @@ jobs: nimble install -y --depsOnly nimble test nimble build_dcli - + nimble build_portalcli diff --git a/.travis.yml b/.travis.yml index 2a3ad32..b41d868 100644 --- a/.travis.yml +++ b/.travis.yml @@ -47,3 +47,4 @@ script: - nimble install -y --depsOnly - nimble test - nimble build_dcli + - nimble build_portalcli diff --git a/eth.nimble b/eth.nimble index d962cdd..e807552 100644 --- a/eth.nimble +++ b/eth.nimble @@ -46,6 +46,9 @@ task test_discv5, "Run discovery v5 tests": task test_discv4, "Run discovery v4 tests": runTest("tests/p2p/test_discovery") +task test_portal, "Run Portal network tests": + runTest("tests/p2p/all_portal_tests") + task test_p2p, "Run p2p tests": runTest("tests/p2p/all_tests") @@ -86,3 +89,6 @@ task test_discv5_full, "Run discovery v5 and its dependencies tests": task build_dcli, "Build dcli": buildBinary("eth/p2p/discoveryv5/dcli") + +task build_portalcli, "Build portalcli": + buildBinary("eth/p2p/portal/portalcli") diff --git a/eth/p2p/discoveryv5/protocol.nim b/eth/p2p/discoveryv5/protocol.nim index fca9350..aa0366c 100644 --- a/eth/p2p/discoveryv5/protocol.nim +++ b/eth/p2p/discoveryv5/protocol.nim @@ -132,16 +132,20 @@ type bootstrapRecords*: seq[Record] ipVote: IpVote enrAutoUpdate: bool - talkProtocols: Table[seq[byte], TalkProtocolHandler] + talkProtocols*: Table[seq[byte], TalkProtocol] # TODO: Table is a bit of + # overkill here, use sequence rng*: ref BrHmacDrbgContext PendingRequest = object node: Node message: seq[byte] - TalkProtocolHandler* = proc(request: seq[byte]): seq[byte] + TalkProtocolHandler* = proc(p: TalkProtocol, request: seq[byte]): seq[byte] {.gcsafe, raises: [Defect].} + TalkProtocol* = ref object of RootObj + protocolHandler*: TalkProtocolHandler + DiscResult*[T] = Result[T, cstring] proc addNode*(d: Protocol, node: Node): bool = @@ -299,15 +303,16 @@ proc handleFindNode(d: Protocol, fromId: NodeId, fromAddr: Address, proc handleTalkReq(d: Protocol, fromId: NodeId, fromAddr: Address, talkreq: TalkReqMessage, reqId: RequestId) = - let protocolHandler = d.talkProtocols.getOrDefault(talkreq.protocol) + let talkProtocol = d.talkProtocols.getOrDefault(talkreq.protocol) let talkresp = - if protocolHandler.isNil(): + if talkProtocol.isNil() or talkProtocol.protocolHandler.isNil(): # Protocol identifier that is not registered and thus not supported. An # empty response is send as per specification. TalkRespMessage(response: @[]) else: - TalkRespMessage(response: protocolHandler(talkreq.request)) + TalkRespMessage(response: talkProtocol.protocolHandler(talkProtocol, + talkreq.request)) let (data, _) = encodeMessagePacket(d.rng[], d.codec, fromId, fromAddr, encodeMessage(talkresp, reqId)) @@ -341,10 +346,10 @@ proc handleMessage(d: Protocol, srcId: NodeId, fromAddr: Address, trace "Timed out or unrequested message", kind = message.kind, origin = fromAddr -proc registerTalkProtocol*(d: Protocol, protocol: seq[byte], - handler: TalkProtocolHandler): DiscResult[void] = +proc registerTalkProtocol*(d: Protocol, protocolId: seq[byte], + protocol: TalkProtocol): DiscResult[void] = # Currently allow only for one handler per talk protocol. - if d.talkProtocols.hasKeyOrPut(protocol, handler): + if d.talkProtocols.hasKeyOrPut(protocolId, protocol): err("Protocol identifier already registered") else: ok() diff --git a/eth/p2p/portal/README.md b/eth/p2p/portal/README.md new file mode 100644 index 0000000..97ab88f --- /dev/null +++ b/eth/p2p/portal/README.md @@ -0,0 +1,48 @@ +# Portal Network Wire Protocol +## Introduction +The `eth/p2p/portal` directory holds a Nim implementation of the +[Portal Network Wire Protocol](https://github.com/ethereum/stateless-ethereum-specs/blob/master/state-network.md#wire-protocol). + +Both specification, at above link, and implementations are still WIP. + +The protocol builds on top of the Node Discovery v5.1 protocol its `talkreq` and +`talkresp` messages. + +For further information on the Nim implementation of the Node Discovery v5.1 +protocol check out the [discv5](../../../doc/discv5.md) page. + +## Test suite +To run the test suite specifically for the Portal wire protocol, run following +command: +```sh +# Install required modules +nimble install +# Run only Portal tests +nimble test_portal +``` + +## portalcli +This is a small command line application that allows you to run a +Discovery v5.1 + Portal node. + +*Note:* Its objective is only to test the protocol wire component, not to actually +serve content. This means it will always return empty lists on content requests. +Perhaps in the future some hardcoded data could added and maybe some test vectors +can be created in such form. + +The `portalcli` application allows you to either run a node, or to specifically +send one of the message types, wait for the response, and then shut down. + +### Example usage +```sh +# Install required modules +# Make sure you have the latest modules, do NOT trust nimble on this. +nimble install +# Build portalcli +nimble build_portalcli +# See all options +./eth/p2p/portal/portalcli --help +# Example command: Ping another node +./eth/p2p/portal/portalcli ping enr: +# Example command: Run discovery + portal node +./eth/p2p/portal/portalcli --log-level:debug --bootnode:enr: diff --git a/eth/p2p/portal/messages.nim b/eth/p2p/portal/messages.nim new file mode 100644 index 0000000..c0cb179 --- /dev/null +++ b/eth/p2p/portal/messages.nim @@ -0,0 +1,153 @@ +# nim-eth - Portal Network- Message types +# Copyright (c) 2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +# As per spec: +# https://github.com/ethereum/stateless-ethereum-specs/blob/master/state-network.md#wire-protocol + +{.push raises: [Defect].} + +import + stint, stew/[results, objects], + ../../ssz/ssz_serialization + +export ssz_serialization, stint + +type + ByteList* = List[byte, 2048] + + MessageKind* = enum + unused = 0x00 + + ping = 0x01 + pong = 0x02 + findnode = 0x03 + nodes = 0x04 + findcontent = 0x05 + foundcontent = 0x06 + advertise = 0x07 + requestproofs = 0x08 + + PingMessage* = object + enrSeq*: uint64 + dataRadius*: UInt256 + + PongMessage* = object + enrSeq*: uint64 + dataRadius*: UInt256 + + FindNodeMessage* = object + distances*: List[uint16, 256] + + NodesMessage* = object + total*: uint8 + enrs*: List[ByteList, 32] # ByteList here is the rlp encoded ENR. This could + # also be limited to 300 bytes instead of 2048 + + FindContentMessage* = object + contentKey*: ByteList + + FoundContentMessage* = object + enrs*: List[ByteList, 32] + payload*: ByteList + + AdvertiseMessage* = List[ByteList, 32] # No container, heh... + + # This would be more consistent with the other messages + # AdvertiseMessage* = object + # contentKeys*: List[ByteList, 32] + + RequestProofsMessage* = object + connectionId*: List[byte, 4] + contentKeys*: List[ByteList, 32] + + Message* = object + case kind*: MessageKind + of ping: + ping*: PingMessage + of pong: + pong*: PongMessage + of findnode: + findNode*: FindNodeMessage + of nodes: + nodes*: NodesMessage + of findcontent: + findcontent*: FindContentMessage + of foundcontent: + foundcontent*: FoundContentMessage + of advertise: + advertise*: AdvertiseMessage + of requestproofs: + requestproofs*: RequestProofsMessage + else: + discard + + SomeMessage* = + PingMessage or PongMessage or + FindNodeMessage or NodesMessage or + FindContentMessage or FoundContentMessage or + AdvertiseMessage or RequestProofsMessage + +template messageKind*(T: typedesc[SomeMessage]): MessageKind = + when T is PingMessage: ping + elif T is PongMessage: pong + elif T is FindNodeMessage: findNode + elif T is NodesMessage: nodes + elif T is FindContentMessage: findcontent + elif T is FoundContentMessage: foundcontent + elif T is AdvertiseMessage: advertise + elif T is RequestProofsMessage: requestproofs + +template toSszType*(x: auto): auto = + mixin toSszType + + when x is UInt256: toBytesLE(x) + else: x + +func fromSszBytes*(T: type UInt256, data: openArray[byte]): + T {.raises: [MalformedSszError, Defect].} = + if data.len != sizeof(result): + raiseIncorrectSize T + + T.fromBytesLE(data) + +proc encodeMessage*[T: SomeMessage](m: T): seq[byte] = + ord(messageKind(T)).byte & SSZ.encode(m) + +proc decodeMessage*(body: openarray[byte]): Result[Message, cstring] = + # Decodes to the specific `Message` type. + if body.len < 1: + return err("No message data") + + var kind: MessageKind + if not checkedEnumAssign(kind, body[0]): + return err("Invalid message type") + + var message = Message(kind: kind) + + try: + case kind + of unused: return err("Invalid message type") + of ping: + message.ping = SSZ.decode(body.toOpenArray(1, body.high), PingMessage) + of pong: + message.pong = SSZ.decode(body.toOpenArray(1, body.high), PongMessage) + of findNode: + message.findNode = SSZ.decode(body.toOpenArray(1, body.high), FindNodeMessage) + of nodes: + message.nodes = SSZ.decode(body.toOpenArray(1, body.high), NodesMessage) + of findcontent: + message.findcontent = SSZ.decode(body.toOpenArray(1, body.high), FindContentMessage) + of foundcontent: + message.foundcontent = SSZ.decode(body.toOpenArray(1, body.high), FoundContentMessage) + of advertise: + message.advertise = SSZ.decode(body.toOpenArray(1, body.high), AdvertiseMessage) + of requestproofs: + message.requestproofs = SSZ.decode(body.toOpenArray(1, body.high), RequestProofsMessage) + except SszError: + return err("Invalid message encoding") + + ok(message) diff --git a/eth/p2p/portal/portalcli.nim b/eth/p2p/portal/portalcli.nim new file mode 100644 index 0000000..30e624d --- /dev/null +++ b/eth/p2p/portal/portalcli.nim @@ -0,0 +1,221 @@ +# nim-eth - Portal Network +# Copyright (c) 2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import + std/[options, strutils, tables], + confutils, confutils/std/net, chronicles, chronicles/topics_registry, + chronos, metrics, metrics/chronos_httpserver, stew/byteutils, + ../../keys, ../../net/nat, + ".."/discoveryv5/[enr, node], ".."/discoveryv5/protocol as discv5_protocol, + ./messages, ./protocol as portal_protocol + +type + PortalCmd* = enum + noCommand + ping + findnode + findcontent + + DiscoveryConf* = object + logLevel* {. + defaultValue: LogLevel.DEBUG + desc: "Sets the log level" + name: "log-level" .}: LogLevel + + udpPort* {. + defaultValue: 9009 + desc: "UDP listening port" + name: "udp-port" .}: uint16 + + listenAddress* {. + defaultValue: defaultListenAddress(config) + desc: "Listening address for the Discovery v5 traffic" + name: "listen-address" }: ValidIpAddress + + bootnodes* {. + desc: "ENR URI of node to bootstrap discovery with. Argument may be repeated" + name: "bootnode" .}: seq[enr.Record] + + nat* {. + desc: "Specify method to use for determining public address. " & + "Must be one of: any, none, upnp, pmp, extip:" + defaultValue: NatConfig(hasExtIp: false, nat: NatAny) + name: "nat" .}: NatConfig + + enrAutoUpdate* {. + defaultValue: false + desc: "Discovery can automatically update its ENR with the IP address " & + "and UDP port as seen by other nodes it communicates with. " & + "This option allows to enable/disable this functionality" + name: "enr-auto-update" .}: bool + + nodeKey* {. + desc: "P2P node private key as hex", + defaultValue: PrivateKey.random(keys.newRng()[]) + name: "nodekey" .}: PrivateKey + + metricsEnabled* {. + defaultValue: false + desc: "Enable the metrics server" + name: "metrics" .}: bool + + metricsAddress* {. + defaultValue: defaultAdminListenAddress(config) + desc: "Listening address of the metrics server" + name: "metrics-address" .}: ValidIpAddress + + metricsPort* {. + defaultValue: 8008 + desc: "Listening HTTP port of the metrics server" + name: "metrics-port" .}: Port + + case cmd* {. + command + defaultValue: noCommand }: PortalCmd + of noCommand: + discard + of ping: + pingTarget* {. + argument + desc: "ENR URI of the node to a send ping message" + name: "node" .}: Node + of findnode: + distance* {. + defaultValue: 255 + desc: "Distance parameter for the findNode message" + name: "distance" .}: uint16 + # TODO: Order here matters as else the help message does not show all the + # information, see: https://github.com/status-im/nim-confutils/issues/15 + findNodeTarget* {. + argument + desc: "ENR URI of the node to send a findNode message" + name: "node" .}: Node + of findcontent: + findContentTarget* {. + argument + desc: "ENR URI of the node to send a findContent message" + name: "node" .}: Node + +func defaultListenAddress*(conf: DiscoveryConf): ValidIpAddress = + (static ValidIpAddress.init("0.0.0.0")) + +func defaultAdminListenAddress*(conf: DiscoveryConf): ValidIpAddress = + (static ValidIpAddress.init("127.0.0.1")) + +proc parseCmdArg*(T: type enr.Record, p: TaintedString): T = + if not fromURI(result, p): + raise newException(ConfigurationError, "Invalid ENR") + +proc completeCmdArg*(T: type enr.Record, val: TaintedString): seq[string] = + return @[] + +proc parseCmdArg*(T: type Node, p: TaintedString): T = + var record: enr.Record + if not fromURI(record, p): + raise newException(ConfigurationError, "Invalid ENR") + + let n = newNode(record) + if n.isErr: + raise newException(ConfigurationError, $n.error) + + if n[].address.isNone(): + raise newException(ConfigurationError, "ENR without address") + + n[] + +proc completeCmdArg*(T: type Node, val: TaintedString): seq[string] = + return @[] + +proc parseCmdArg*(T: type PrivateKey, p: TaintedString): T = + try: + result = PrivateKey.fromHex(string(p)).tryGet() + except CatchableError: + raise newException(ConfigurationError, "Invalid private key") + +proc completeCmdArg*(T: type PrivateKey, val: TaintedString): seq[string] = + return @[] + +proc discover(d: discv5_protocol.Protocol) {.async.} = + while true: + let discovered = await d.queryRandom() + info "Lookup finished", nodes = discovered.len + await sleepAsync(30.seconds) + +proc run(config: DiscoveryConf) = + let + rng = newRng() + bindIp = config.listenAddress + udpPort = Port(config.udpPort) + # TODO: allow for no TCP port mapping! + (extIp, _, extUdpPort) = setupAddress(config.nat, + config.listenAddress, udpPort, udpPort, "dcli") + + let d = newProtocol(config.nodeKey, + extIp, none(Port), extUdpPort, + bootstrapRecords = config.bootnodes, + bindIp = bindIp, bindPort = udpPort, + enrAutoUpdate = config.enrAutoUpdate, + rng = rng) + + d.open() + + let portal = PortalProtocol.new(d) + + if config.metricsEnabled: + let + address = config.metricsAddress + port = config.metricsPort + notice "Starting metrics HTTP server", + url = "http://" & $address & ":" & $port & "/metrics" + try: + chronos_httpserver.startMetricsHttpServer($address, port) + except CatchableError as exc: raise exc + except Exception as exc: raiseAssert exc.msg # TODO fix metrics + + case config.cmd + of ping: + let pong = waitFor portal.ping(config.pingTarget) + + if pong.isOk(): + echo pong.get() + else: + echo pong.error + of findnode: + let distances = List[uint16, 256](@[config.distance]) + let nodes = waitFor portal.findNode(config.findNodeTarget, distances) + + if nodes.isOk(): + echo nodes.get() + else: + echo nodes.error + of findcontent: + proc random(T: type UInt256, rng: var BrHmacDrbgContext): T = + var key: UInt256 + brHmacDrbgGenerate(addr rng, addr key, csize_t(sizeof(key))) + + key + + # For now just random content keys + let contentKey = ByteList(@(UInt256.random(rng[]).toBytes())) + let foundContent = waitFor portal.findContent(config.findContentTarget, + contentKey) + + if foundContent.isOk(): + echo foundContent.get() + else: + echo foundContent.error + + of noCommand: + d.start() + waitfor(discover(d)) + +when isMainModule: + let config = DiscoveryConf.load() + + setLogLevel(config.logLevel) + + run(config) diff --git a/eth/p2p/portal/protocol.nim b/eth/p2p/portal/protocol.nim new file mode 100644 index 0000000..f56eee4 --- /dev/null +++ b/eth/p2p/portal/protocol.nim @@ -0,0 +1,164 @@ +# nim-eth - Portal Network +# Copyright (c) 2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.push raises: [Defect].} + +import + stew/[results, byteutils], chronicles, + ../../rlp, + ../discoveryv5/[protocol, node], + ./messages + +export messages + +logScope: + topics = "portal" + +const + PortalProtocolId* = "portal".toBytes() + +type + PortalProtocol* = ref object of TalkProtocol + baseProtocol*: protocol.Protocol + dataRadius*: UInt256 + +proc handlePing(p: PortalProtocol, ping: PingMessage): + seq[byte] = + let p = PongMessage(enrSeq: p.baseProtocol.localNode.record.seqNum, + dataRadius: p.dataRadius) + + encodeMessage(p) + +proc handleFindNode(p: PortalProtocol, fn: FindNodeMessage): seq[byte] = + if fn.distances.len == 0: + let enrs = List[ByteList, 32](@[]) + encodeMessage(NodesMessage(total: 1, enrs: enrs)) + elif fn.distances.contains(0): + # A request for our own record. + let enr = ByteList(rlp.encode(p.baseProtocol.localNode.record)) + encodeMessage(NodesMessage(total: 1, enrs: List[ByteList, 32](@[enr]))) + else: + # TODO: Not implemented for now, sending empty back. + let enrs = List[ByteList, 32](@[]) + encodeMessage(NodesMessage(total: 1, enrs: enrs)) + +proc handleFindContent(p: PortalProtocol, ping: FindContentMessage): seq[byte] = + # TODO: Neither payload nor enrs implemented, sending empty back. + let + enrs = List[ByteList, 32](@[]) + payload = ByteList(@[]) + encodeMessage(FoundContentMessage(enrs: enrs, payload: payload)) + +proc handleAdvertise(p: PortalProtocol, ping: AdvertiseMessage): seq[byte] = + # TODO: Not implemented + let + connectionId = List[byte, 4](@[]) + contentKeys = List[ByteList, 32](@[]) + encodeMessage(RequestProofsMessage(connectionId: connectionId, + contentKeys: contentKeys)) + +proc messageHandler*(protocol: TalkProtocol, request: seq[byte]): seq[byte] = + doAssert(protocol of PortalProtocol) + + let p = PortalProtocol(protocol) + + let decoded = decodeMessage(request) + if decoded.isOk(): + let message = decoded.get() + trace "Received message response", kind = message.kind + case message.kind + of MessageKind.ping: + p.handlePing(message.ping) + of MessageKind.findnode: + p.handleFindNode(message.findNode) + of MessageKind.findcontent: + p.handleFindContent(message.findcontent) + of MessageKind.advertise: + p.handleAdvertise(message.advertise) + else: + @[] + else: + @[] + +proc new*(T: type PortalProtocol, baseProtocol: protocol.Protocol, + dataRadius = UInt256.high()): T = + let proto = PortalProtocol( + protocolHandler: messageHandler, + baseProtocol: baseProtocol, + dataRadius: dataRadius) + + proto.baseProtocol.registerTalkProtocol(PortalProtocolId, proto).expect( + "Only one protocol should have this id") + + return proto + +proc ping*(p: PortalProtocol, dst: Node): + Future[DiscResult[PongMessage]] {.async.} = + let ping = PingMessage(enrSeq: p.baseProtocol.localNode.record.seqNum, + dataRadius: p.dataRadius) + + # TODO: This send and response handling code could be more generalized for the + # different message types. + trace "Send message request", dstId = dst.id, kind = MessageKind.ping + let talkresp = await talkreq(p.baseProtocol, dst, PortalProtocolId, + encodeMessage(ping)) + + if talkresp.isOk(): + let decoded = decodeMessage(talkresp.get().response) + if decoded.isOk(): + let message = decoded.get() + if message.kind == pong: + return ok(message.pong) + else: + return err("Invalid message response received") + else: + return err(decoded.error) + else: + return err(talkresp.error) + +proc findNode*(p: PortalProtocol, dst: Node, distances: List[uint16, 256]): + Future[DiscResult[NodesMessage]] {.async.} = + let fn = FindNodeMessage(distances: distances) + + trace "Send message request", dstId = dst.id, kind = MessageKind.findnode + let talkresp = await talkreq(p.baseProtocol, dst, PortalProtocolId, + encodeMessage(fn)) + + if talkresp.isOk(): + let decoded = decodeMessage(talkresp.get().response) + if decoded.isOk(): + let message = decoded.get() + if message.kind == nodes: + # TODO: Verify nodes here + return ok(message.nodes) + else: + return err("Invalid message response received") + else: + return err(decoded.error) + else: + return err(talkresp.error) + +proc findContent*(p: PortalProtocol, dst: Node, contentKey: ByteList): + Future[DiscResult[FoundContentMessage]] {.async.} = + let fc = FindContentMessage(contentKey: contentKey) + + trace "Send message request", dstId = dst.id, kind = MessageKind.findcontent + let talkresp = await talkreq(p.baseProtocol, dst, PortalProtocolId, + encodeMessage(fc)) + + if talkresp.isOk(): + let decoded = decodeMessage(talkresp.get().response) + if decoded.isOk(): + let message = decoded.get() + if message.kind == foundcontent: + return ok(message.foundcontent) + else: + return err("Invalid message response received") + else: + return err(decoded.error) + else: + return err(talkresp.error) diff --git a/eth/ssz/bitseqs.nim b/eth/ssz/bitseqs.nim new file mode 100644 index 0000000..c4ca4f1 --- /dev/null +++ b/eth/ssz/bitseqs.nim @@ -0,0 +1,313 @@ +# nim-eth +# Copyright (c) 2018-2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.push raises: [Defect].} + +import + stew/[bitops2, endians2, ptrops] + +type + Bytes = seq[byte] + + BitSeq* = distinct Bytes + ## The current design of BitSeq tries to follow precisely + ## the bitwise representation of the SSZ bitlists. + ## This is a relatively compact representation, but as + ## evident from the code below, many of the operations + ## are not trivial. + + BitArray*[bits: static int] = object + bytes*: array[(bits + 7) div 8, byte] + +func bitsLen*(bytes: openArray[byte]): int = + let + bytesCount = bytes.len + lastByte = bytes[bytesCount - 1] + markerPos = log2trunc(lastByte) + + bytesCount * 8 - (8 - markerPos) + +template len*(s: BitSeq): int = + bitsLen(Bytes s) + +template len*(a: BitArray): int = + a.bits + +func add*(s: var BitSeq, value: bool) = + let + lastBytePos = s.Bytes.len - 1 + lastByte = s.Bytes[lastBytePos] + + if (lastByte and byte(128)) == 0: + # There is at least one leading zero, so we have enough + # room to store the new bit + let markerPos = log2trunc(lastByte) + s.Bytes[lastBytePos].changeBit markerPos, value + s.Bytes[lastBytePos].setBit markerPos + 1 + else: + s.Bytes[lastBytePos].changeBit 7, value + s.Bytes.add byte(1) + +func toBytesLE(x: uint): array[sizeof(x), byte] = + # stew/endians2 supports explicitly sized uints only + when sizeof(uint) == 4: + static: doAssert sizeof(uint) == sizeof(uint32) + toBytesLE(x.uint32) + elif sizeof(uint) == 8: + static: doAssert sizeof(uint) == sizeof(uint64) + toBytesLE(x.uint64) + else: + static: doAssert false, "requires a 32-bit or 64-bit platform" + +func loadLEBytes(WordType: type, bytes: openArray[byte]): WordType = + # TODO: this is a temporary proc until the endians API is improved + var shift = 0 + for b in bytes: + result = result or (WordType(b) shl shift) + shift += 8 + +func storeLEBytes(value: SomeUnsignedInt, dst: var openArray[byte]) = + doAssert dst.len <= sizeof(value) + let bytesLE = toBytesLE(value) + copyMem(addr dst[0], unsafeAddr bytesLE[0], dst.len) + +template loopOverWords(lhs, rhs: BitSeq, + lhsIsVar, rhsIsVar: static bool, + WordType: type, + lhsBits, rhsBits, body: untyped) = + const hasRhs = astToStr(lhs) != astToStr(rhs) + + let bytesCount = len Bytes(lhs) + when hasRhs: doAssert len(Bytes(rhs)) == bytesCount + + var fullWordsCount = bytesCount div sizeof(WordType) + let lastWordSize = bytesCount mod sizeof(WordType) + + block: + var lhsWord: WordType + when hasRhs: + var rhsWord: WordType + var firstByteOfLastWord, lastByteOfLastWord: int + + # TODO: Returning a `var` value from an iterator is always safe due to + # the way inlining works, but currently the compiler reports an error + # when a local variable escapes. We have to cheat it with this location + # obfuscation through pointers: + template lhsBits: auto = (addr(lhsWord))[] + + when hasRhs: + template rhsBits: auto = (addr(rhsWord))[] + + template lastWordBytes(bitseq): auto = + Bytes(bitseq).toOpenArray(firstByteOfLastWord, lastByteOfLastWord) + + template initLastWords = + lhsWord = loadLEBytes(WordType, lastWordBytes(lhs)) + when hasRhs: rhsWord = loadLEBytes(WordType, lastWordBytes(rhs)) + + if lastWordSize == 0: + firstByteOfLastWord = bytesCount - sizeof(WordType) + lastByteOfLastWord = bytesCount - 1 + dec fullWordsCount + else: + firstByteOfLastWord = bytesCount - lastWordSize + lastByteOfLastWord = bytesCount - 1 + + initLastWords() + let markerPos = log2trunc(lhsWord) + when hasRhs: doAssert log2trunc(rhsWord) == markerPos + + lhsWord.clearBit markerPos + when hasRhs: rhsWord.clearBit markerPos + + body + + when lhsIsVar or rhsIsVar: + let + markerBit = uint(1 shl markerPos) + mask = markerBit - 1'u + + when lhsIsVar: + let lhsEndResult = (lhsWord and mask) or markerBit + storeLEBytes(lhsEndResult, lastWordBytes(lhs)) + + when rhsIsVar: + let rhsEndResult = (rhsWord and mask) or markerBit + storeLEBytes(rhsEndResult, lastWordBytes(rhs)) + + var lhsCurrAddr = cast[ptr WordType](unsafeAddr Bytes(lhs)[0]) + let lhsEndAddr = offset(lhsCurrAddr, fullWordsCount) + when hasRhs: + var rhsCurrAddr = cast[ptr WordType](unsafeAddr Bytes(rhs)[0]) + + while lhsCurrAddr < lhsEndAddr: + template lhsBits: auto = lhsCurrAddr[] + when hasRhs: + template rhsBits: auto = rhsCurrAddr[] + + body + + lhsCurrAddr = offset(lhsCurrAddr, 1) + when hasRhs: rhsCurrAddr = offset(rhsCurrAddr, 1) + +iterator words*(x: var BitSeq): var uint = + loopOverWords(x, x, true, false, uint, word, wordB): + yield word + +iterator words*(x: BitSeq): uint = + loopOverWords(x, x, false, false, uint, word, word): + yield word + +iterator words*(a, b: BitSeq): (uint, uint) = + loopOverWords(a, b, false, false, uint, wordA, wordB): + yield (wordA, wordB) + +iterator words*(a: var BitSeq, b: BitSeq): (var uint, uint) = + loopOverWords(a, b, true, false, uint, wordA, wordB): + yield (wordA, wordB) + +iterator words*(a, b: var BitSeq): (var uint, var uint) = + loopOverWords(a, b, true, true, uint, wordA, wordB): + yield (wordA, wordB) + +func `[]`*(s: BitSeq, pos: Natural): bool {.inline.} = + doAssert pos < s.len + s.Bytes.getBit pos + +func `[]=`*(s: var BitSeq, pos: Natural, value: bool) {.inline.} = + doAssert pos < s.len + s.Bytes.changeBit pos, value + +func setBit*(s: var BitSeq, pos: Natural) {.inline.} = + doAssert pos < s.len + setBit s.Bytes, pos + +func clearBit*(s: var BitSeq, pos: Natural) {.inline.} = + doAssert pos < s.len + clearBit s.Bytes, pos + +func init*(T: type BitSeq, len: int): T = + result = BitSeq newSeq[byte](1 + len div 8) + Bytes(result).setBit len + +func init*(T: type BitArray): T = + # The default zero-initializatio is fine + discard + +template `[]`*(a: BitArray, pos: Natural): bool = + getBit a.bytes, pos + +template `[]=`*(a: var BitArray, pos: Natural, value: bool) = + changeBit a.bytes, pos, value + +template setBit*(a: var BitArray, pos: Natural) = + setBit a.bytes, pos + +template clearBit*(a: var BitArray, pos: Natural) = + clearBit a.bytes, pos + +# TODO: Submit this to the standard library as `cmp` +# At the moment, it doesn't work quite well because Nim selects +# the generic cmp[T] from the system module instead of choosing +# the openArray overload +func compareArrays[T](a, b: openArray[T]): int = + result = cmp(a.len, b.len) + if result != 0: return + + for i in 0 ..< a.len: + result = cmp(a[i], b[i]) + if result != 0: return + +template cmp*(a, b: BitSeq): int = + compareArrays(Bytes a, Bytes b) + +template `==`*(a, b: BitSeq): bool = + cmp(a, b) == 0 + +func `$`*(a: BitSeq | BitArray): string = + let length = a.len + result = newStringOfCap(2 + length) + result.add "0b" + for i in countdown(length - 1, 0): + result.add if a[i]: '1' else: '0' + +func incl*(tgt: var BitSeq, src: BitSeq) = + # Update `tgt` to include the bits of `src`, as if applying `or` to each bit + doAssert tgt.len == src.len + for tgtWord, srcWord in words(tgt, src): + tgtWord = tgtWord or srcWord + +func overlaps*(a, b: BitSeq): bool = + for wa, wb in words(a, b): + if (wa and wb) != 0: + return true + +func countOverlap*(a, b: BitSeq): int = + var res = 0 + for wa, wb in words(a, b): + res += countOnes(wa and wb) + res + +func isSubsetOf*(a, b: BitSeq): bool = + let alen = a.len + doAssert b.len == alen + for i in 0 ..< alen: + if a[i] and not b[i]: + return false + true + +func isZeros*(x: BitSeq): bool = + for w in words(x): + if w != 0: return false + return true + +func countOnes*(x: BitSeq): int = + # Count the number of set bits + var res = 0 + for w in words(x): + res += w.countOnes() + res + +func clear*(x: var BitSeq) = + for w in words(x): + w = 0 + +func countZeros*(x: BitSeq): int = + x.len() - x.countOnes() + +template bytes*(x: BitSeq): untyped = + seq[byte](x) + +iterator items*(x: BitArray): bool = + for i in 0.. byte(1): + raise newException(MalformedSszError, "invalid boolean value") + data[0] == 1 + +template fromSszBytes*(T: type BitSeq, bytes: openArray[byte]): auto = + BitSeq @bytes + +proc `[]`[T, U, V](s: openArray[T], x: HSlice[U, V]) {.error: + "Please don't use openArray's [] as it allocates a result sequence".} + +template checkForForbiddenBits(ResulType: type, + input: openArray[byte], + expectedBits: static int64) = + ## This checks if the input contains any bits set above the maximum + ## sized allowed. We only need to check the last byte to verify this: + const bitsInLastByte = (expectedBits mod 8) + when bitsInLastByte != 0: + # As an example, if there are 3 bits expected in the last byte, + # we calculate a bitmask equal to 11111000. If the input has any + # raised bits in range of the bitmask, this would be a violation + # of the size of the BitArray: + const forbiddenBitsMask = byte(byte(0xff) shl bitsInLastByte) + + if (input[^1] and forbiddenBitsMask) != 0: + raiseIncorrectSize ResulType + +func readSszValue*[T](input: openArray[byte], val: var T) + {.raises: [SszError, Defect].} = + mixin fromSszBytes, toSszType + + template readOffsetUnchecked(n: int): uint32 {.used.}= + fromSszBytes(uint32, input.toOpenArray(n, n + offsetSize - 1)) + + template readOffset(n: int): int {.used.} = + let offset = readOffsetUnchecked(n) + if offset > input.len.uint32: + raise newException(MalformedSszError, "SSZ list element offset points past the end of the input") + int(offset) + + when val is BitList: + if input.len == 0: + raise newException(MalformedSszError, "Invalid empty SSZ BitList value") + + # Since our BitLists have an in-memory representation that precisely + # matches their SSZ encoding, we can deserialize them as regular Lists: + const maxExpectedSize = (val.maxLen div 8) + 1 + type MatchingListType = List[byte, maxExpectedSize] + + when false: + # TODO: Nim doesn't like this simple type coercion, + # we'll rely on `cast` for now (see below) + readSszValue(input, MatchingListType val) + else: + static: + # As a sanity check, we verify that the coercion is accepted by the compiler: + doAssert MatchingListType(val) is MatchingListType + readSszValue(input, cast[ptr MatchingListType](addr val)[]) + + let resultBytesCount = len bytes(val) + + if bytes(val)[resultBytesCount - 1] == 0: + raise newException(MalformedSszError, "SSZ BitList is not properly terminated") + + if resultBytesCount == maxExpectedSize: + checkForForbiddenBits(T, input, val.maxLen + 1) + + elif val is List|array: + type E = type val[0] + + when E is byte: + val.setOutputSize input.len + if input.len > 0: + copyMem(addr val[0], unsafeAddr input[0], input.len) + + elif isFixedSize(E): + const elemSize = fixedPortionSize(E) + if input.len mod elemSize != 0: + var ex = new SszSizeMismatchError + ex.deserializedType = cstring typetraits.name(T) + ex.actualSszSize = input.len + ex.elementSize = elemSize + raise ex + val.setOutputSize input.len div elemSize + for i in 0 ..< val.len: + let offset = i * elemSize + readSszValue(input.toOpenArray(offset, offset + elemSize - 1), val[i]) + + else: + if input.len == 0: + # This is an empty list. + # The default initialization of the return value is fine. + val.setOutputSize 0 + return + elif input.len < offsetSize: + raise newException(MalformedSszError, "SSZ input of insufficient size") + + var offset = readOffset 0 + let resultLen = offset div offsetSize + + if resultLen == 0: + # If there are too many elements, other constraints detect problems + # (not monotonically increasing, past end of input, or last element + # not matching up with its nextOffset properly) + raise newException(MalformedSszError, "SSZ list incorrectly encoded of zero length") + + val.setOutputSize resultLen + for i in 1 ..< resultLen: + let nextOffset = readOffset(i * offsetSize) + if nextOffset <= offset: + raise newException(MalformedSszError, "SSZ list element offsets are not monotonically increasing") + else: + readSszValue(input.toOpenArray(offset, nextOffset - 1), val[i - 1]) + offset = nextOffset + + readSszValue(input.toOpenArray(offset, input.len - 1), val[resultLen - 1]) + + elif val is UintN|bool: + val = fromSszBytes(T, input) + + elif val is BitArray: + if sizeof(val) != input.len: + raiseIncorrectSize(T) + checkForForbiddenBits(T, input, val.bits) + copyMem(addr val.bytes[0], unsafeAddr input[0], input.len) + + elif val is object|tuple: + let inputLen = uint32 input.len + const minimallyExpectedSize = uint32 fixedPortionSize(T) + + if inputLen < minimallyExpectedSize: + raise newException(MalformedSszError, "SSZ input of insufficient size") + + enumInstanceSerializedFields(val, fieldName, field): + const boundingOffsets = getFieldBoundingOffsets(T, fieldName) + + # type FieldType = type field # buggy + # For some reason, Nim gets confused about the alias here. This could be a + # generics caching issue caused by the use of distinct types. Such an + # issue is very scary in general. + # The bug can be seen with the two List[uint64, N] types that exist in + # the spec, with different N. + + type SszType = type toSszType(declval type(field)) + + when isFixedSize(SszType): + const + startOffset = boundingOffsets[0] + endOffset = boundingOffsets[1] + else: + let + startOffset = readOffsetUnchecked(boundingOffsets[0]) + endOffset = if boundingOffsets[1] == -1: inputLen + else: readOffsetUnchecked(boundingOffsets[1]) + + when boundingOffsets.isFirstOffset: + if startOffset != minimallyExpectedSize: + raise newException(MalformedSszError, "SSZ object dynamic portion starts at invalid offset") + + if startOffset > endOffset: + raise newException(MalformedSszError, "SSZ field offsets are not monotonically increasing") + elif endOffset > inputLen: + raise newException(MalformedSszError, "SSZ field offset points past the end of the input") + elif startOffset < minimallyExpectedSize: + raise newException(MalformedSszError, "SSZ field offset points outside bounding offsets") + + # TODO The extra type escaping here is a work-around for a Nim issue: + when type(field) is type(SszType): + readSszValue( + input.toOpenArray(int(startOffset), int(endOffset - 1)), + field) + else: + field = fromSszBytes( + type(field), + input.toOpenArray(int(startOffset), int(endOffset - 1))) + + else: + unsupported T diff --git a/eth/ssz/ssz_serialization.nim b/eth/ssz/ssz_serialization.nim new file mode 100644 index 0000000..2c42181 --- /dev/null +++ b/eth/ssz/ssz_serialization.nim @@ -0,0 +1,247 @@ +# nim-eth - Limited SSZ implementation +# Copyright (c) 2018-2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.push raises: [Defect].} + +## SSZ serialization for core SSZ types, as specified in: +# https://github.com/ethereum/eth2.0-specs/blob/v1.0.1/ssz/simple-serialize.md#serialization + +import + std/[typetraits, options], + stew/[endians2, leb128, objects], + serialization, serialization/testing/tracing, + ./bytes_reader, ./types + +export + serialization, types, bytes_reader + +type + SszReader* = object + stream: InputStream + + SszWriter* = object + stream: OutputStream + + SizePrefixed*[T] = distinct T + SszMaxSizeExceeded* = object of SerializationError + + VarSizedWriterCtx = object + fixedParts: WriteCursor + offset: int + + FixedSizedWriterCtx = object + +serializationFormat SSZ + +SSZ.setReader SszReader +SSZ.setWriter SszWriter, PreferredOutput = seq[byte] + +template sizePrefixed*[TT](x: TT): untyped = + type T = TT + SizePrefixed[T](x) + +proc init*(T: type SszReader, stream: InputStream): T {.raises: [Defect].} = + T(stream: stream) + +proc writeFixedSized(s: var (OutputStream|WriteCursor), x: auto) + {.raises: [Defect, IOError].} = + mixin toSszType + + when x is byte: + s.write x + elif x is bool: + s.write byte(ord(x)) + elif x is UintN: + when cpuEndian == bigEndian: + s.write toBytesLE(x) + else: + s.writeMemCopy x + elif x is array: + when x[0] is byte: + trs "APPENDING FIXED SIZE BYTES", x + s.write x + else: + for elem in x: + trs "WRITING FIXED SIZE ARRAY ELEMENT" + s.writeFixedSized toSszType(elem) + elif x is tuple|object: + enumInstanceSerializedFields(x, fieldName, field): + trs "WRITING FIXED SIZE FIELD", fieldName + s.writeFixedSized toSszType(field) + else: + unsupported x.type + +template writeOffset(cursor: var WriteCursor, offset: int) = + write cursor, toBytesLE(uint32 offset) + +template supports*(_: type SSZ, T: type): bool = + mixin toSszType + anonConst compiles(fixedPortionSize toSszType(declval T)) + +func init*(T: type SszWriter, stream: OutputStream): T {.raises: [Defect].} = + result.stream = stream + +proc writeVarSizeType(w: var SszWriter, value: auto) + {.gcsafe, raises: [Defect, IOError].} + +proc beginRecord*(w: var SszWriter, TT: type): auto {.raises: [Defect].} = + type T = TT + when isFixedSize(T): + FixedSizedWriterCtx() + else: + const offset = when T is array: len(T) * offsetSize + else: fixedPortionSize(T) + VarSizedWriterCtx(offset: offset, + fixedParts: w.stream.delayFixedSizeWrite(offset)) + +template writeField*(w: var SszWriter, + ctx: var auto, + fieldName: string, + field: auto) = + mixin toSszType + when ctx is FixedSizedWriterCtx: + writeFixedSized(w.stream, toSszType(field)) + else: + type FieldType = type toSszType(field) + + when isFixedSize(FieldType): + writeFixedSized(ctx.fixedParts, toSszType(field)) + else: + trs "WRITING OFFSET ", ctx.offset, " FOR ", fieldName + writeOffset(ctx.fixedParts, ctx.offset) + let initPos = w.stream.pos + trs "WRITING VAR SIZE VALUE OF TYPE ", name(FieldType) + when FieldType is BitList: + trs "BIT SEQ ", bytes(field) + writeVarSizeType(w, toSszType(field)) + ctx.offset += w.stream.pos - initPos + +template endRecord*(w: var SszWriter, ctx: var auto) = + when ctx is VarSizedWriterCtx: + finalize ctx.fixedParts + +proc writeSeq[T](w: var SszWriter, value: seq[T]) + {.raises: [Defect, IOError].} = + # Please note that `writeSeq` exists in order to reduce the code bloat + # produced from generic instantiations of the unique `List[N, T]` types. + when isFixedSize(T): + trs "WRITING LIST WITH FIXED SIZE ELEMENTS" + for elem in value: + w.stream.writeFixedSized toSszType(elem) + trs "DONE" + else: + trs "WRITING LIST WITH VAR SIZE ELEMENTS" + var offset = value.len * offsetSize + var cursor = w.stream.delayFixedSizeWrite offset + for elem in value: + cursor.writeFixedSized uint32(offset) + let initPos = w.stream.pos + w.writeVarSizeType toSszType(elem) + offset += w.stream.pos - initPos + finalize cursor + trs "DONE" + +proc writeVarSizeType(w: var SszWriter, value: auto) + {.raises: [Defect, IOError].} = + trs "STARTING VAR SIZE TYPE" + + when value is List: + # We reduce code bloat by forwarding all `List` types to a general `seq[T]` + # proc. + writeSeq(w, asSeq value) + elif value is BitList: + # ATTENTION! We can reuse `writeSeq` only as long as our BitList type is + # implemented to internally match the binary representation of SSZ BitLists + # in memory. + writeSeq(w, bytes value) + elif value is object|tuple|array: + trs "WRITING OBJECT OR ARRAY" + var ctx = beginRecord(w, type value) + enumerateSubFields(value, field): + writeField w, ctx, astToStr(field), field + endRecord w, ctx + else: + unsupported type(value) + +proc writeValue*(w: var SszWriter, x: auto) + {.gcsafe, raises: [Defect, IOError].} = + mixin toSszType + type T = type toSszType(x) + + when isFixedSize(T): + w.stream.writeFixedSized toSszType(x) + else: + w.writeVarSizeType toSszType(x) + +func sszSize*(value: auto): int {.gcsafe, raises: [Defect].} + +func sszSizeForVarSizeList[T](value: openArray[T]): int = + mixin toSszType + result = len(value) * offsetSize + for elem in value: + result += sszSize(toSszType elem) + +func sszSize*(value: auto): int {.gcsafe, raises: [Defect].} = + mixin toSszType + type T = type toSszType(value) + + when isFixedSize(T): + anonConst fixedPortionSize(T) + + elif T is array|List: + type E = ElemType(T) + when isFixedSize(E): + len(value) * anonConst(fixedPortionSize(E)) + elif T is HashArray: + sszSizeForVarSizeList(value.data) + elif T is array: + sszSizeForVarSizeList(value) + else: + sszSizeForVarSizeList(asSeq value) + + elif T is BitList: + return len(bytes(value)) + + elif T is object|tuple: + result = anonConst fixedPortionSize(T) + enumInstanceSerializedFields(value, _{.used.}, field): + type FieldType = type toSszType(field) + when not isFixedSize(FieldType): + result += sszSize(toSszType field) + + else: + unsupported T + +proc writeValue*[T](w: var SszWriter, x: SizePrefixed[T]) + {.raises: [Defect, IOError].} = + var cursor = w.stream.delayVarSizeWrite(Leb128.maxLen(uint64)) + let initPos = w.stream.pos + w.writeValue T(x) + let length = toBytes(uint64(w.stream.pos - initPos), Leb128) + cursor.finalWrite length.toOpenArray() + +proc readValue*[T](r: var SszReader, val: var T) + {.raises: [Defect, SszError, IOError].} = + when isFixedSize(T): + const minimalSize = fixedPortionSize(T) + if r.stream.readable(minimalSize): + readSszValue(r.stream.read(minimalSize), val) + else: + raise newException(MalformedSszError, "SSZ input of insufficient size") + else: + # TODO(zah) Read the fixed portion first and precisely measure the + # size of the dynamic portion to consume the right number of bytes. + readSszValue(r.stream.read(r.stream.len.get), val) + +proc readSszBytes*[T](data: openArray[byte], val: var T) {. + raises: [Defect, MalformedSszError, SszSizeMismatchError].} = + when isFixedSize(T): + const minimalSize = fixedPortionSize(T) + if data.len < minimalSize: + raise newException(MalformedSszError, "SSZ input of insufficient size") + + readSszValue(data, val) diff --git a/eth/ssz/types.nim b/eth/ssz/types.nim new file mode 100644 index 0000000..ec7ba6e --- /dev/null +++ b/eth/ssz/types.nim @@ -0,0 +1,258 @@ +# nim-eth - Limited SSZ implementation +# Copyright (c) 2018-2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.push raises: [Defect].} + +import + std/[tables, options, typetraits, strformat], + stew/shims/macros, stew/[byteutils, bitops2, objects], + serialization/[object_serialization, errors], + ./bitseqs + +export bitseqs + +const + offsetSize* = 4 + bytesPerChunk* = 32 + +type + UintN* = SomeUnsignedInt + BasicType* = bool|UintN + + Limit* = int64 + + List*[T; maxLen: static Limit] = distinct seq[T] + BitList*[maxLen: static Limit] = distinct BitSeq + + # Note for readers: + # We use `array` for `Vector` and + # `BitArray` for `BitVector` + + SszError* = object of SerializationError + + MalformedSszError* = object of SszError + + SszSizeMismatchError* = object of SszError + deserializedType*: cstring + actualSszSize*: int + elementSize*: int + +template asSeq*(x: List): auto = distinctBase(x) + +template init*[T](L: type List, x: seq[T], N: static Limit): auto = + List[T, N](x) + +template init*[T, N](L: type List[T, N], x: seq[T]): auto = + List[T, N](x) + +template `$`*(x: List): auto = $(distinctBase x) +template len*(x: List): auto = len(distinctBase x) +template low*(x: List): auto = low(distinctBase x) +template high*(x: List): auto = high(distinctBase x) +template `[]`*(x: List, idx: auto): untyped = distinctBase(x)[idx] +template `[]=`*(x: var List, idx: auto, val: auto) = distinctBase(x)[idx] = val +template `==`*(a, b: List): bool = distinctBase(a) == distinctBase(b) + +template `&`*(a, b: List): auto = (type(a)(distinctBase(a) & distinctBase(b))) + +template items* (x: List): untyped = items(distinctBase x) +template pairs* (x: List): untyped = pairs(distinctBase x) +template mitems*(x: var List): untyped = mitems(distinctBase x) +template mpairs*(x: var List): untyped = mpairs(distinctBase x) + +template contains* (x: List, val: auto): untyped = contains(distinctBase x, val) + +proc add*(x: var List, val: auto): bool = + if x.len < x.maxLen: + add(distinctBase x, val) + true + else: + false + +proc setLen*(x: var List, newLen: int): bool = + if newLen <= x.maxLen: + setLen(distinctBase x, newLen) + true + else: + false + +template init*(L: type BitList, x: seq[byte], N: static Limit): auto = + BitList[N](data: x) + +template init*[N](L: type BitList[N], x: seq[byte]): auto = + L(data: x) + +template init*(T: type BitList, len: int): auto = T init(BitSeq, len) +template len*(x: BitList): auto = len(BitSeq(x)) +template bytes*(x: BitList): auto = seq[byte](x) +template `[]`*(x: BitList, idx: auto): auto = BitSeq(x)[idx] +template `[]=`*(x: var BitList, idx: auto, val: bool) = BitSeq(x)[idx] = val +template `==`*(a, b: BitList): bool = BitSeq(a) == BitSeq(b) +template setBit*(x: var BitList, idx: Natural) = setBit(BitSeq(x), idx) +template clearBit*(x: var BitList, idx: Natural) = clearBit(BitSeq(x), idx) +template overlaps*(a, b: BitList): bool = overlaps(BitSeq(a), BitSeq(b)) +template incl*(a: var BitList, b: BitList) = incl(BitSeq(a), BitSeq(b)) +template isSubsetOf*(a, b: BitList): bool = isSubsetOf(BitSeq(a), BitSeq(b)) +template isZeros*(x: BitList): bool = isZeros(BitSeq(x)) +template countOnes*(x: BitList): int = countOnes(BitSeq(x)) +template countZeros*(x: BitList): int = countZeros(BitSeq(x)) +template countOverlap*(x, y: BitList): int = countOverlap(BitSeq(x), BitSeq(y)) +template `$`*(a: BitList): string = $(BitSeq(a)) + +iterator items*(x: BitList): bool = + for i in 0 ..< x.len: + yield x[i] + +macro unsupported*(T: typed): untyped = + # TODO: {.fatal.} breaks compilation even in `compiles()` context, + # so we use this macro instead. It's also much better at figuring + # out the actual type that was used in the instantiation. + # File both problems as issues. + error "SSZ serialization of the type " & humaneTypeName(T) & " is not supported" + +template ElemType*(T: type array): untyped = + type(default(T)[low(T)]) + +template ElemType*(T: type seq): untyped = + type(default(T)[0]) + +template ElemType*(T: type List): untyped = + T.T + +func isFixedSize*(T0: type): bool {.compileTime.} = + mixin toSszType, enumAllSerializedFields + + type T = type toSszType(declval T0) + + when T is BasicType: + return true + elif T is array: + return isFixedSize(ElemType(T)) + elif T is object|tuple: + enumAllSerializedFields(T): + when not isFixedSize(FieldType): + return false + return true + +func fixedPortionSize*(T0: type): int {.compileTime.} = + mixin enumAllSerializedFields, toSszType + + type T = type toSszType(declval T0) + + when T is BasicType: sizeof(T) + elif T is array: + type E = ElemType(T) + when isFixedSize(E): int(len(T)) * fixedPortionSize(E) + else: int(len(T)) * offsetSize + elif T is object|tuple: + enumAllSerializedFields(T): + when isFixedSize(FieldType): + result += fixedPortionSize(FieldType) + else: + result += offsetSize + else: + unsupported T0 + +# TODO This should have been an iterator, but the VM can't compile the +# code due to "too many registers required". +proc fieldInfos*(RecordType: type): seq[tuple[name: string, + offset: int, + fixedSize: int, + branchKey: string]] = + mixin enumAllSerializedFields + + var + offsetInBranch = {"": 0}.toTable + nestedUnder = initTable[string, string]() + + enumAllSerializedFields(RecordType): + const + isFixed = isFixedSize(FieldType) + fixedSize = when isFixed: fixedPortionSize(FieldType) + else: 0 + branchKey = when fieldCaseDiscriminator.len == 0: "" + else: fieldCaseDiscriminator & ":" & $fieldCaseBranches + fieldSize = when isFixed: fixedSize + else: offsetSize + + nestedUnder[fieldName] = branchKey + + var fieldOffset: int + offsetInBranch.withValue(branchKey, val): + fieldOffset = val[] + val[] += fieldSize + do: + try: + let parentBranch = nestedUnder.getOrDefault(fieldCaseDiscriminator, "") + fieldOffset = offsetInBranch[parentBranch] + offsetInBranch[branchKey] = fieldOffset + fieldSize + except KeyError as e: + raiseAssert e.msg + + result.add((fieldName, fieldOffset, fixedSize, branchKey)) + +func getFieldBoundingOffsetsImpl(RecordType: type, fieldName: static string): + tuple[fieldOffset, nextFieldOffset: int, isFirstOffset: bool] + {.compileTime.} = + result = (-1, -1, false) + var fieldBranchKey: string + var isFirstOffset = true + + for f in fieldInfos(RecordType): + if fieldName == f.name: + result[0] = f.offset + if f.fixedSize > 0: + result[1] = result[0] + f.fixedSize + return + else: + fieldBranchKey = f.branchKey + result.isFirstOffset = isFirstOffset + + elif result[0] != -1 and + f.fixedSize == 0 and + f.branchKey == fieldBranchKey: + # We have found the next variable sized field + result[1] = f.offset + return + + if f.fixedSize == 0: + isFirstOffset = false + +func getFieldBoundingOffsets*(RecordType: type, fieldName: static string): + tuple[fieldOffset, nextFieldOffset: int, isFirstOffset: bool] + {.compileTime.} = + ## Returns the start and end offsets of a field. + ## + ## For fixed-size fields, the start offset points to the first + ## byte of the field and the end offset points to 1 byte past the + ## end of the field. + ## + ## For variable-size fields, the returned offsets point to the + ## statically known positions of the 32-bit offset values written + ## within the SSZ object. You must read the 32-bit values stored + ## at the these locations in order to obtain the actual offsets. + ## + ## For variable-size fields, the end offset may be -1 when the + ## designated field is the last variable sized field within the + ## object. Then the SSZ object boundary known at run-time marks + ## the end of the variable-size field. + type T = RecordType + anonConst getFieldBoundingOffsetsImpl(T, fieldName) + +template enumerateSubFields*(holder, fieldVar, body: untyped) = + when holder is array: + for fieldVar in holder: body + else: + enumInstanceSerializedFields(holder, _{.used.}, fieldVar): body + +method formatMsg*( + err: ref SszSizeMismatchError, + filename: string): string {.gcsafe, raises: [Defect].} = + try: + &"SSZ size mismatch, element {err.elementSize}, actual {err.actualSszSize}, type {err.deserializedType}, file {filename}" + except CatchableError: + "SSZ size mismatch" diff --git a/tests/p2p/all_portal_tests.nim b/tests/p2p/all_portal_tests.nim new file mode 100644 index 0000000..e6993e7 --- /dev/null +++ b/tests/p2p/all_portal_tests.nim @@ -0,0 +1,5 @@ +{.used.} + +import + ./test_portal_encoding, + ./test_portal \ No newline at end of file diff --git a/tests/p2p/all_tests.nim b/tests/p2p/all_tests.nim index f73574e..f37823f 100644 --- a/tests/p2p/all_tests.nim +++ b/tests/p2p/all_tests.nim @@ -1,5 +1,6 @@ import ./all_discv5_tests, + ./all_portal_tests, ./test_auth, ./test_crypt, ./test_discovery, diff --git a/tests/p2p/test_discoveryv5.nim b/tests/p2p/test_discoveryv5.nim index 9f17bfe..b98b997 100644 --- a/tests/p2p/test_discoveryv5.nim +++ b/tests/p2p/test_discoveryv5.nim @@ -645,10 +645,13 @@ procSuite "Discovery v5 Tests": rng, PrivateKey.random(rng[]), localAddress(20303)) talkProtocol = "echo".toBytes() - proc handler(request: seq[byte]): seq[byte] {.gcsafe, raises: [Defect].} = + proc handler(protocol: TalkProtocol, request: seq[byte]): seq[byte] + {.gcsafe, raises: [Defect].} = request - check node2.registerTalkProtocol(talkProtocol, handler).isOk() + let echoProtocol = TalkProtocol(protocolHandler: handler) + + check node2.registerTalkProtocol(talkProtocol, echoProtocol).isOk() let talkresp = await discv5_protocol.talkreq(node1, node2.localNode, talkProtocol, "hello".toBytes()) @@ -667,13 +670,16 @@ procSuite "Discovery v5 Tests": rng, PrivateKey.random(rng[]), localAddress(20303)) talkProtocol = "echo".toBytes() - proc handler(request: seq[byte]): seq[byte] {.gcsafe, raises: [Defect].} = + proc handler(protocol: TalkProtocol, request: seq[byte]): seq[byte] + {.gcsafe, raises: [Defect].} = request + let echoProtocol = TalkProtocol(protocolHandler: handler) + check: - node2.registerTalkProtocol(talkProtocol, handler).isOk() - node2.registerTalkProtocol(talkProtocol, handler).isErr() - node2.registerTalkProtocol("test".toBytes(), handler).isOk() + node2.registerTalkProtocol(talkProtocol, echoProtocol).isOk() + node2.registerTalkProtocol(talkProtocol, echoProtocol).isErr() + node2.registerTalkProtocol("test".toBytes(), echoProtocol).isOk() await node1.closeWait() await node2.closeWait() diff --git a/tests/p2p/test_portal.nim b/tests/p2p/test_portal.nim new file mode 100644 index 0000000..adb10f4 --- /dev/null +++ b/tests/p2p/test_portal.nim @@ -0,0 +1,103 @@ +# nim-eth - Portal Network +# Copyright (c) 2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.used.} + +import + chronos, testutils/unittests, + ../../eth/keys, # for rng + ../../eth/p2p/discoveryv5/protocol as discv5_protocol, + ../../eth/p2p/portal/protocol as portal_protocol, + ./discv5_test_helper + +proc random(T: type UInt256, rng: var BrHmacDrbgContext): T = + var key: UInt256 + brHmacDrbgGenerate(addr rng, addr key, csize_t(sizeof(key))) + + key + +procSuite "Portal Tests": + let rng = newRng() + + asyncTest "Portal Ping/Pong": + let + node1 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20302)) + node2 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20303)) + + proto1 = PortalProtocol.new(node1) + proto2 = PortalProtocol.new(node2) + + let pong = await proto1.ping(proto2.baseProtocol.localNode) + + check: + pong.isOk() + pong.get().enrSeq == 1'u64 + pong.get().dataRadius == UInt256.high() + + await node1.closeWait() + await node2.closeWait() + + asyncTest "Portal FindNode/Nodes": + let + node1 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20302)) + node2 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20303)) + + proto1 = PortalProtocol.new(node1) + proto2 = PortalProtocol.new(node2) + + block: # Find itself + let nodes = await proto1.findNode(proto2.baseProtocol.localNode, + List[uint16, 256](@[0'u16])) + + check: + nodes.isOk() + nodes.get().total == 1'u8 + nodes.get().enrs.len() == 1 + + block: # Find nothing + let nodes = await proto1.findNode(proto2.baseProtocol.localNode, + List[uint16, 256](@[])) + + check: + nodes.isOk() + nodes.get().total == 1'u8 + nodes.get().enrs.len() == 0 + + block: # Find for distance + # TODO: Add test when implemented + discard + + await node1.closeWait() + await node2.closeWait() + + asyncTest "Portal FindContent/FoundContent": + let + node1 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20302)) + node2 = initDiscoveryNode( + rng, PrivateKey.random(rng[]), localAddress(20303)) + + proto1 = PortalProtocol.new(node1) + proto2 = PortalProtocol.new(node2) + + let contentKey = ByteList(@(UInt256.random(rng[]).toBytes())) + + let foundContent = await proto1.findContent(proto2.baseProtocol.localNode, + contentKey) + + check: + foundContent.isOk() + # TODO: adjust when implemented + foundContent.get().enrs.len() == 0 + foundContent.get().payload.len() == 0 + + await node1.closeWait() + await node2.closeWait() diff --git a/tests/p2p/test_portal_encoding.nim b/tests/p2p/test_portal_encoding.nim new file mode 100644 index 0000000..e30bc01 --- /dev/null +++ b/tests/p2p/test_portal_encoding.nim @@ -0,0 +1,156 @@ +# nim-eth - Portal Network +# Copyright (c) 2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.used.} + +import + std/unittest, + stint, stew/[byteutils, results], + ../../eth/p2p/portal/messages + +suite "Portal Protocol Message Encodings": + test "Ping Request": + var dataRadius: UInt256 + let + enrSeq = 1'u64 + p = PingMessage(enrSeq: enrSeq, dataRadius: dataRadius) + + let encoded = encodeMessage(p) + check encoded.toHex == + "0101000000000000000000000000000000000000000000000000000000000000000000000000000000" + let decoded = decodeMessage(encoded) + check decoded.isOk() + + let message = decoded.get() + check: + message.kind == ping + message.ping.enrSeq == enrSeq + message.ping.dataRadius == dataRadius + + test "Pong Response": + var dataRadius: UInt256 + let + enrSeq = 1'u64 + p = PongMessage(enrSeq: enrSeq, dataRadius: dataRadius) + + let encoded = encodeMessage(p) + check encoded.toHex == + "0201000000000000000000000000000000000000000000000000000000000000000000000000000000" + let decoded = decodeMessage(encoded) + check decoded.isOk() + + let message = decoded.get() + check: + message.kind == pong + message.pong.enrSeq == enrSeq + message.pong.dataRadius == dataRadius + + test "FindNode Request": + let + distances = List[uint16, 256](@[0x0100'u16]) + fn = FindNodeMessage(distances: distances) + + let encoded = encodeMessage(fn) + check encoded.toHex == "03040000000001" + + let decoded = decodeMessage(encoded) + check decoded.isOk() + + let message = decoded.get() + check: + message.kind == findnode + message.findnode.distances == distances + + test "Nodes Response (empty)": + let + total = 0x1'u8 + n = NodesMessage(total: total) + + let encoded = encodeMessage(n) + check encoded.toHex == "040105000000" + + let decoded = decodeMessage(encoded) + check decoded.isOk() + + let message = decoded.get() + check: + message.kind == nodes + message.nodes.total == total + message.nodes.enrs.len() == 0 + + test "FindContent Request": + let + contentKey = ByteList(@[byte 0x01, 0x02, 0x03]) + fn = FindContentMessage(contentKey: contentKey) + + let encoded = encodeMessage(fn) + check encoded.toHex == "0504000000010203" + + let decoded = decodeMessage(encoded) + check decoded.isOk() + + let message = decoded.get() + check: + message.kind == findcontent + message.findcontent.contentKey == contentKey + + test "FoundContent Response (empty enrs)": + let + enrs = List[ByteList, 32](@[]) + payload = ByteList(@[byte 0x01, 0x02, 0x03]) + n = FoundContentMessage(enrs: enrs, payload: payload) + + let encoded = encodeMessage(n) + check encoded.toHex == "060800000008000000010203" + + let decoded = decodeMessage(encoded) + check decoded.isOk() + + let message = decoded.get() + check: + message.kind == foundcontent + message.foundcontent.enrs.len() == 0 + message.foundcontent.payload == payload + + test "Advertise Request": + let + contentKeys = List[ByteList, 32](List(@[ByteList(@[byte 0x01, 0x02, 0x03])])) + am = AdvertiseMessage(contentKeys) + # am = AdvertiseMessage(contentKeys: contentKeys) + + let encoded = encodeMessage(am) + check encoded.toHex == "0704000000010203" + # "070400000004000000010203" + + let decoded = decodeMessage(encoded) + check decoded.isOk() + + let message = decoded.get() + check: + message.kind == advertise + message.advertise == contentKeys + # message.advertise.contentKeys == contentKeys + + test "RequestProofs Response": # That sounds weird + let + connectionId = List[byte, 4](@[byte 0x01, 0x02, 0x03, 0x04]) + contentKeys = + List[ByteList, 32](List(@[ByteList(@[byte 0x01, 0x02, 0x03])])) + n = RequestProofsMessage(connectionId: connectionId, + contentKeys: contentKeys) + + let encoded = encodeMessage(n) + check encoded.toHex == "08080000000c0000000102030404000000010203" + + let decoded = decodeMessage(encoded) + check decoded.isOk() + + let message = decoded.get() + check: + message.kind == requestproofs + message.requestproofs.connectionId == connectionId + message.requestproofs.contentKeys == contentKeys