diff --git a/README.md b/README.md index acef5fc..def60dc 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,8 @@ the `EthereumNode` type: ``` nim proc newEthereumNode*(keys: KeyPair, + listeningAddress: Address, + networkId: uint, chain: AbstractChainDB, clientId = "nim-eth-p2p", addAllCapabilities = true): EthereumNode = @@ -38,6 +40,14 @@ proc newEthereumNode*(keys: KeyPair, library for utilities that will help you generate and manage such keys. +`listeningAddress`: + The network interface and port where your client will be + accepting incoming connections. + +`networkId`: + The Ethereum network ID. The client will disconnect immediately + from any peers who don't use the same network. + `chain`: An abstract instance of the Ethereum blockchain associated with the node. This library allows you to plug any instance @@ -60,7 +70,7 @@ proc newEthereumNode*(keys: KeyPair, node.addCapability(eth) node.addCapability(ssh) ``` - + Each supplied protocol identifier is a name of a protocol introduced by the `rlpxProtocol` macro discussed later in this document. @@ -69,16 +79,14 @@ the network. To start the connection process, call `node.connectToNetwork`: ``` nim proc connectToNetwork*(node: var EthereumNode, - address: Address, - listeningPort = Port(30303), bootstrapNodes: openarray[ENode], - networkId: int, - startListening = true) + startListening = true, + enableDiscovery = true) ``` The `EthereumNode` will automatically find and maintan a pool of peers using the Ethereum node discovery protocol. You can access the pool as -`node.peers`. +`node.peers`. ## Communicating with Peers using RLPx @@ -106,7 +114,7 @@ a 3-letter identifier for the protocol and the current protocol version: Here is how the [DevP2P wire protocol](https://github.com/ethereum/wiki/wiki/%C3%90%CE%9EVp2p-Wire-Protocol) might look like: ``` nim -rlpxProtocol p2p, 0: +rlpxProtocol p2p(version = 0): proc hello(peer: Peer, version: uint, clientId: string, @@ -130,25 +138,28 @@ and the asynchronous code responsible for handling the incoming messages. ### Protocol state -The protocol implementations are expected to maintain a state and to act like -a state machine handling the incoming messages. To achieve this, each protocol -may define a `State` object that can be accessed as a `state` field of the `Peer` -object: +The protocol implementations are expected to maintain a state and to act +like a state machine handling the incoming messages. You are allowed to +define an arbitrary state type that can be specified in the `peerState` +protocol option. Later, instances of the state object can be obtained +though the `state` pseudo-field of the `Peer` object: ``` nim -rlpxProtocol abc, 1: - type State = object - receivedMsgsCount: int +type AbcPeerState = object + receivedMsgsCount: int + +rlpxProtocol abc(version = 1, + peerState = AbcPeerState): proc incomingMessage(p: Peer) = p.state.receivedMsgsCount += 1 ``` -Besides the per-peer state demonstrated above, there is also support for -maintaining a network-wide state. In the example above, we'll just have -to change the name of the state type to `NetworkState` and the accessor -expression to `p.network.state`. +Besides the per-peer state demonstrated above, there is also support +for maintaining a network-wide state. It's enabled by specifying the +`networkState` option of the protocol and the state object can be obtained +through accessor of the same name. The state objects are initialized to zero by default, but you can modify this behaviour by overriding the following procs for your state types: @@ -158,11 +169,8 @@ proc initProtocolState*(state: var MyPeerState, p: Peer) proc initProtocolState*(state: var MyNetworkState, n: EthereumNode) ``` -Please note that the state type will have to be placed outside of the -protocol definition in order to achieve this. - -Sometimes, you'll need to access the state of another protocol. To do this, -specify the protocol identifier to the `state` accessors: +Sometimes, you'll need to access the state of another protocol. +To do this, specify the protocol identifier to the `state` accessors: ``` nim echo "ABC protocol messages: ", peer.state(abc).receivedMsgCount @@ -218,7 +226,7 @@ rlpxProtocol les, 2: requestResponse: proc getProofs(p: Peer, proofs: openarray[ProofRequest]) proc proofs(p: Peer, BV: uint, proofs: openarray[Blob]) - + ... ``` @@ -234,16 +242,15 @@ be specified for each individual call and the default value can be overridden on the level of individual message, or the entire protocol: ``` nim -rlpxProtocol abc, 1: - timeout = 5000 # value in milliseconds - useRequestIds = false - +rlpxProtocol abc(version = 1, + useRequestIds = false, + timeout = 5000): # value in milliseconds requestResponse: proc myReq(dataId: int, timeout = 3000) proc myRes(data: string) ``` -By default, the library will take care of inserting a hidden `reqId` +By default, the library will take care of inserting a hidden `reqId` parameter as used in the [LES protocol](https://github.com/zsfelfoldi/go-ethereum/wiki/Light-Ethereum-Subprotocol-%28LES%29), but you can disable this behavior by overriding the protocol setting `useRequestIds`. @@ -255,7 +262,7 @@ also include handlers for certain important events such as newly connected peers or misbehaving or disconnecting peers: ``` nim -rlpxProtocol les, 2: +rlpxProtocol les(version = 2): onPeerConnected do (peer: Peer): asyncCheck peer.status [ "networkId": rlp.encode(1), diff --git a/eth_p2p.nim b/eth_p2p.nim index dee8bee..ddec4fd 100644 --- a/eth_p2p.nim +++ b/eth_p2p.nim @@ -9,30 +9,31 @@ # import - tables, deques, macros, sets, algorithm, hashes, times, - random, options, sequtils, typetraits, os, - asyncdispatch2, asyncdispatch2/timer, - rlp, ranges/[stackarrays, ptr_arith], nimcrypto, chronicles, - eth_keys, eth_common, - eth_p2p/[kademlia, discovery, auth, rlpxcrypt, enode] + tables, algorithm, random, + asyncdispatch2, asyncdispatch2/timer, chronicles, + eth_keys, eth_common/eth_types, + eth_p2p/[kademlia, discovery, enode, peer_pool, rlpx], + eth_p2p/private/types + +types.forwardPublicTypes export - enode, kademlia, options + rlpx, enode, kademlia -proc addProtocol(n: var EthereumNode, p: ProtocolInfo) = +proc addCapability*(n: var EthereumNode, p: ProtocolInfo) = assert n.connectionState == ConnectionState.None - let pos = lowerBound(n.rlpxProtocols, p) + let pos = lowerBound(n.rlpxProtocols, p, rlpx.cmp) n.rlpxProtocols.insert(p, pos) - n.rlpxCapabilities.insert(Capability(name: p.name, version: p.version), pos) + n.rlpxCapabilities.insert(p.asCapability, pos) template addCapability*(n: var EthereumNode, Protocol: type) = - addProtocol(n, Protocol.protocolInfo) + addCapability(n, Protocol.protocolInfo) proc newEthereumNode*(keys: KeyPair, address: Address, networkId: uint, chain: AbstractChainDB, - clientId = clientId, + clientId = "nim-eth-p2p/0.2.0", # TODO: read this value from nimble somehow addAllCapabilities = true): EthereumNode = new result result.keys = keys @@ -45,7 +46,7 @@ proc newEthereumNode*(keys: KeyPair, if addAllCapabilities: for p in rlpxProtocols: - result.addProtocol(p) + result.addCapability(p) proc processIncoming(server: StreamServer, remote: StreamTransport): Future[void] {.async, gcsafe.} = @@ -69,6 +70,13 @@ proc startListening*(node: EthereumNode) = udata = cast[pointer](node)) node.listeningServer.start() +proc initProtocolStates*(node: EthereumNode) = + # TODO: This should be moved to a private module + node.protocolStates.newSeq(rlpxProtocols.len) + for p in node.rlpxProtocols: + if p.networkStateInitializer != nil: + node.protocolStates[p.index] = ((p.networkStateInitializer)(node)) + proc connectToNetwork*(node: EthereumNode, bootstrapNodes: seq[ENode], startListening = true, @@ -80,17 +88,14 @@ proc connectToNetwork*(node: EthereumNode, node.address, bootstrapNodes) - node.peerPool = newPeerPool(node, node.chain, node.networkId, + node.peerPool = newPeerPool(node, node.networkId, node.keys, node.discovery, node.clientId, node.address.tcpPort) if startListening: eth_p2p.startListening(node) - node.protocolStates.newSeq(rlpxProtocols.len) - for p in node.rlpxProtocols: - if p.networkStateInitializer != nil: - node.protocolStates[p.index] = p.networkStateInitializer(node) + node.initProtocolStates() if startListening: node.listeningServer.start() diff --git a/eth_p2p.nimble b/eth_p2p.nimble index d8b7006..859671f 100644 --- a/eth_p2p.nimble +++ b/eth_p2p.nimble @@ -17,14 +17,15 @@ requires "nim > 0.18.0", "byteutils", "chronicles", "asyncdispatch2", - "eth_common" + "eth_common", + "package_visible_types" -proc runTest(name: string, lang = "c") = exec "nim " & lang & " --experimental:ForLoopMacros -r tests/" & name +proc runTest(name: string, lang = "c") = + exec "nim " & lang & " -d:testing --experimental:ForLoopMacros -r tests/" & name task test, "Runs the test suite": - runTest "testecies" - runTest "testauth" - runTest "testcrypt" runTest "testenode" runTest "tdiscovery" runTest "tserver" + runTest "all_tests" + diff --git a/eth_p2p/blockchain_sync.nim b/eth_p2p/blockchain_sync.nim index 1585d31..a7d7d86 100644 --- a/eth_p2p/blockchain_sync.nim +++ b/eth_p2p/blockchain_sync.nim @@ -1,3 +1,13 @@ +import + sets, options, random, hashes, + asyncdispatch2, chronicles, eth_common/eth_types, + private/types, rlpx, peer_pool, rlpx_protocols/eth_protocol, + ../eth_p2p.nim + +const + minPeersToStartSync* = 2 # Wait for consensus of at least this + # number of peers before syncing + type SyncStatus* = enum syncSuccess @@ -26,6 +36,8 @@ type trustedPeers: HashSet[Peer] hasOutOfOrderBlocks: bool +proc hash*(p: Peer): Hash {.inline.} = hash(cast[pointer](p)) + proc endIndex(b: WantedBlocks): BlockNumber = result = b.startIndex result += (b.numBlocks - 1).u256 @@ -228,6 +240,7 @@ proc randomTrustedPeer(ctx: SyncContext): Peer = inc i proc startSyncWithPeer(ctx: SyncContext, peer: Peer) {.async.} = + debug "start sync ", peer, trustedPeers = ctx.trustedPeers.len if ctx.trustedPeers.len >= minPeersToStartSync: # We have enough trusted peers. Validate new peer against trusted if await peersAgreeOnChain(peer, ctx.randomTrustedPeer()): @@ -280,7 +293,7 @@ proc onPeerConnected(ctx: SyncContext, peer: Peer) = error "startSyncWithPeer failed", msg = f.readError.msg, peer proc onPeerDisconnected(ctx: SyncContext, p: Peer) = - echo "onPeerDisconnected" + debug "peer disconnected ", peer = p ctx.trustedPeers.excl(p) proc startSync(ctx: SyncContext) = diff --git a/eth_p2p/blockchain_utils.nim b/eth_p2p/blockchain_utils.nim new file mode 100644 index 0000000..3859c2e --- /dev/null +++ b/eth_p2p/blockchain_utils.nim @@ -0,0 +1,41 @@ +import + eth_common/[eth_types, state_accessors] + +# TODO: Perhaps we can move this to eth-common + +proc getBlockHeaders*(db: AbstractChainDb, + req: BlocksRequest): seq[BlockHeader] = + result = newSeqOfCap[BlockHeader](req.maxResults) + + var foundBlock: BlockHeader + if db.getBlockHeader(req.startBlock, foundBlock): + result.add foundBlock + + while uint64(result.len) < req.maxResults: + if not db.getSuccessorHeader(foundBlock, foundBlock): + break + result.add foundBlock + +template fetcher*(fetcherName, fetchingFunc, InputType, ResultType: untyped) = + proc fetcherName*(db: AbstractChainDb, + lookups: openarray[InputType]): seq[ResultType] = + for lookup in lookups: + let fetched = fetchingFunc(db, lookup) + if fetched.hasData: + # TODO: should there be an else clause here. + # Is the peer responsible of figuring out that + # some of the requested items were not found? + result.add deref(fetched) + +fetcher getContractCodes, getContractCode, ContractCodeRequest, Blob +fetcher getBlockBodies, getBlockBody, KeccakHash, BlockBody +fetcher getStorageNodes, getStorageNode, KeccakHash, Blob +fetcher getReceipts, getReceipt, KeccakHash, Receipt +fetcher getProofs, getProof, ProofRequest, Blob +fetcher getHeaderProofs, getHeaderProof, ProofRequest, Blob + +proc getHelperTrieProofs*(db: AbstractChainDb, + reqs: openarray[HelperTrieProofRequest], + outNodes: var seq[Blob], outAuxData: var seq[Blob]) = + discard + diff --git a/eth_p2p/kademlia.nim b/eth_p2p/kademlia.nim index 9aeb1d7..465c234 100644 --- a/eth_p2p/kademlia.nim +++ b/eth_p2p/kademlia.nim @@ -46,7 +46,7 @@ const FIND_CONCURRENCY = 3 # parallel find node lookups ID_SIZE = 256 -proc toNodeId(pk: PublicKey): NodeId = +proc toNodeId*(pk: PublicKey): NodeId = readUintBE[256](keccak256.digest(pk.getRaw()).data) proc newNode*(pk: PublicKey, address: Address): Node = @@ -67,8 +67,10 @@ proc newNode*(enode: ENode): Node = proc distanceTo(n: Node, id: NodeId): UInt256 = n.id xor id proc `$`*(n: Node): string = - # "Node[" & $n.node & "]" - "Node[" & $n.node.address.ip & ":" & $n.node.address.udpPort & "]" + if n == nil: + "Node[local]" + else: + "Node[" & $n.node.address.ip & ":" & $n.node.address.udpPort & "]" proc hash*(n: Node): hashes.Hash = hash(n.node.pubkey.data) proc `==`*(a, b: Node): bool = a.node.pubkey == b.node.pubkey diff --git a/eth_p2p/mock_peers.nim b/eth_p2p/mock_peers.nim new file mode 100644 index 0000000..4192ef9 --- /dev/null +++ b/eth_p2p/mock_peers.nim @@ -0,0 +1,210 @@ +import + macros, deques, algorithm, + asyncdispatch2, eth_keys, rlp, eth_common/eth_types, + private/types, rlpx, ../eth_p2p + +type + Action = proc (p: Peer, data: Rlp): Future[void] + + ProtocolMessagePair = object + protocol: ProtocolInfo + id: int + + ExpectedMsg = object + msg: ProtocolMessagePair + response: Action + + MockConf* = ref object + keys*: KeyPair + address*: Address + networkId*: uint + chain*: AbstractChainDb + clientId*: string + waitForHello*: bool + + devp2pHandshake: ExpectedMsg + handshakes: seq[ExpectedMsg] + protocols: seq[ProtocolInfo] + + expectedMsgs: Deque[ExpectedMsg] + receivedMsgsCount: int + +var + nextUnusedMockPort = 40304 + +proc toAction(a: Action): Action = a + +proc toAction[N](actions: array[N, Action]): Action = + mixin await + result = proc (peer: Peer, data: Rlp) {.async.} = + for a in actions: + await a(peer, data) + +proc toAction(a: proc (): Future[void]): Action = + result = proc (peer: Peer, data: Rlp) {.async.} = + await a() + +proc toAction(a: proc (peer: Peer): Future[void]): Action = + result = proc (peer: Peer, data: Rlp) {.async.} = + await a(peer) + +proc delay*(duration: int): Action = + result = proc (p: Peer, data: Rlp) {.async.} = + await sleepAsync(duration) + +proc reply(bytes: Bytes): Action = + result = proc (p: Peer, data: Rlp) {.async.} = + await p.sendMsg(bytes) + +proc reply*[Msg](msg: Msg): Action = + mixin await + result = proc (p: Peer, data: Rlp) {.async.} = + await p.send(msg) + +proc localhostAddress*(port: int): Address = + let port = Port(port) + result = Address(udpPort: port, tcpPort: port, ip: parseIpAddress("127.0.0.1")) + +proc makeProtoMsgPair(MsgType: type): ProtocolMessagePair = + mixin msgProtocol, protocolInfo + result.protocol = MsgType.msgProtocol.protocolInfo + result.id = MsgType.msgId + +proc readReqId*(rlp: Rlp): int = + var r = rlp + return r.read(int) + +proc expectationViolationMsg(mock: MockConf, + reason: string, + receivedMsg: ptr MessageInfo): string = + result = "[Mock expectation violated] " & reason & ": " & receivedMsg.name + for i in 0 ..< mock.expectedMsgs.len: + let expected = mock.expectedMsgs[i].msg + result.add "\n " & expected.protocol.messages[expected.id].name + if i == mock.receivedMsgsCount: result.add " <- we are here" + result.add "\n" + +proc addProtocol(mock: MockConf, p: ProtocolInfo): ProtocolInfo = + new result + deepCopy(result[], p[]) + + proc incomingMsgHandler(p: Peer, receivedMsgId: int, rlp: Rlp): Future[void] = + let (receivedMsgProto, receivedMsgInfo) = p.getMsgMetadata(receivedMsgId) + let expectedMsgIdx = mock.receivedMsgsCount + + template fail(reason: string) = + stdout.write mock.expectationViolationMsg(reason, receivedMsgInfo) + quit 1 + + if expectedMsgIdx > mock.expectedMsgs.len: + fail "Mock peer received more messages than expected" + + let expectedMsg = mock.expectedMsgs[expectedMsgIdx] + if receivedMsgInfo.id != expectedMsg.msg.id or + receivedMsgProto.name != expectedMsg.msg.protocol.name: + fail "Mock peer received an unexpected message" + + inc mock.receivedMsgsCount + if expectedMsg.response != nil: + return expectedMsg.response(p, rlp) + else: + result = newFuture[void]() + result.complete() + + for m in mitems(result.messages): + m.thunk = incomingMsgHandler + + result.handshake = nil + + # TODO This mock conf can override this + result.disconnectHandler = nil + + mock.protocols.add result + +proc addHandshake*(mock: MockConf, msg: auto) = + var msgInfo = makeProtoMsgPair(msg.type) + msgInfo.protocol = mock.addProtocol(msgInfo.protocol) + let expectedMsg = ExpectedMsg(msg: msgInfo, response: reply(msg)) + + when msg is p2p.hello: + devp2pHandshake = expectedMsg + else: + mock.handshakes.add expectedMsg + +proc addCapability*(mock: MockConf, Protocol: type) = + mixin defaultTestingHandshake, protocolInfo + + when compiles(defaultTestingHandshake(Protocol)): + mock.addHandshake(defaultTestingHandshake(Protocol)) + else: + discard mock.addProtocol(Protocol.protocolInfo) + +proc expectImpl(mock: MockConf, msg: ProtocolMessagePair, action: Action) = + mock.expectedMsgs.addLast ExpectedMsg(msg: msg, response: action) + +macro expect*(mock: MockConf, MsgType: type, handler: untyped = nil): untyped = + if handler.kind in {nnkLambda, nnkDo}: + handler.addPragma ident("async") + + result = newCall( + bindSym("expectImpl"), + mock, + newCall(bindSym"makeProtoMsgPair", MsgType.getType), + newCall(bindSym"toAction", handler)) + +proc newMockPeer*(userConfigurator: proc (m: MockConf)): EthereumNode = + var mockConf = new MockConf + mockConf.keys = newKeyPair() + mockConf.address = localhostAddress(nextUnusedMockPort) + inc nextUnusedMockPort + mockConf.networkId = 1'u + mockConf.clientId = "Mock Peer" + mockConf.waitForHello = true + mockConf.expectedMsgs = initDeque[ExpectedMsg]() + + userConfigurator(mockConf) + + var node = newEthereumNode(mockConf.keys, + mockConf.address, + mockConf.networkId, + mockConf.chain, + mockConf.clientId, + addAllCapabilities = false) + + mockConf.handshakes.sort do (lhs, rhs: ExpectedMsg) -> int: + # this is intentially sorted in reverse order, so we + # can add them in the correct order below. + return -cmp(lhs.msg.protocol.index, rhs.msg.protocol.index) + + for h in mockConf.handshakes: + mockConf.expectedMsgs.addFirst h + + for p in mockConf.protocols: + node.addCapability p + + when false: + # TODO: This part doesn't work correctly yet. + # rlpx{Connect,Accept} control the handshake. + if mockConf.devp2pHandshake.response != nil: + mockConf.expectedMsgs.addFirst mockConf.devp2pHandshake + else: + proc sendHello(p: Peer, data: Rlp) {.async.} = + await p.hello(devp2pVersion, + mockConf.clientId, + node.rlpxCapabilities, + uint(node.address.tcpPort), + node.keys.pubkey.getRaw()) + + mockConf.expectedMsgs.addFirst ExpectedMsg( + msg: makeProtoMsgPair(p2p.hello), + response: sendHello) + + node.initProtocolStates() + node.startListening() + return node + +proc rlpxConnect*(node, otherNode: EthereumNode): Future[Peer] = + let otherAsRemote = newNode(initENode(otherNode.keys.pubKey, + otherNode.address)) + return rlpx.rlpxConnect(node, otherAsRemote) + diff --git a/eth_p2p/peer_pool.nim b/eth_p2p/peer_pool.nim index d5eb295..9c5924e 100644 --- a/eth_p2p/peer_pool.nim +++ b/eth_p2p/peer_pool.nim @@ -1,12 +1,17 @@ # PeerPool attempts to keep connections to at least min_peers # on the given network. +import + os, tables, times, random, + asyncdispatch2, chronicles, rlp, eth_keys, + private/types, discovery, kademlia, rlpx + const lookupInterval = 5 connectLoopSleepMs = 2000 proc newPeerPool*(network: EthereumNode, - chainDb: AbstractChainDB, networkId: uint, keyPair: KeyPair, + networkId: uint, keyPair: KeyPair, discovery: DiscoveryProtocol, clientId: string, listenPort = Port(30303), minPeers = 10): PeerPool = new result @@ -72,7 +77,7 @@ proc connect(p: PeerPool, remote: Node): Future[Peer] {.async.} = # try: # self.logger.debug("Connecting to %s...", remote) # peer = await wait_with_token( - # handshake(remote, self.privkey, self.peer_class, self.chaindb, self.network_id), + # handshake(remote, self.privkey, self.peer_class, self.network_id), # token=self.cancel_token, # timeout=HANDSHAKE_TIMEOUT) # return peer @@ -97,40 +102,10 @@ proc lookupRandomNode(p: PeerPool) {.async.} = proc getRandomBootnode(p: PeerPool): seq[Node] = @[p.discovery.bootstrapNodes.rand()] -proc peerFinished(p: PeerPool, peer: Peer) = - ## Remove the given peer from our list of connected nodes. - ## This is passed as a callback to be called when a peer finishes. - p.connectedNodes.del(peer.remote) - - for o in p.observers.values: - if not o.onPeerDisconnected.isNil: - o.onPeerDisconnected(peer) - -proc run(peer: Peer, peerPool: PeerPool) {.async.} = - # TODO: This is a stub that should be implemented in rlpx.nim - - try: - while true: - var (nextMsgId, nextMsgData) = await peer.recvMsg() - if nextMsgId == 1: - debug "Run got disconnect msg", reason = nextMsgData.listElem(0).toInt(uint32).DisconnectionReason, peer - break - else: - # debug "Got msg: ", msg = nextMsgId - await peer.dispatchMsg(nextMsgId, nextMsgData) - except: - error "Failed to read from peer", - err = getCurrentExceptionMsg(), - stackTrace = getCurrentException().getStackTrace() - - peerPool.peerFinished(peer) - proc connectToNode*(p: PeerPool, n: Node) {.async.} = let peer = await p.connect(n) if not peer.isNil: info "Connection established", peer - ensureFuture peer.run(p) - p.connectedNodes[peer.remote] = peer for o in p.observers.values: if not o.onPeerConnected.isNil: diff --git a/eth_p2p/private/types.nim b/eth_p2p/private/types.nim index b88c541..c4812de 100644 --- a/eth_p2p/private/types.nim +++ b/eth_p2p/private/types.nim @@ -1,4 +1,10 @@ -block: +import + deques, tables, + package_visible_types, + rlp, asyncdispatch2, eth_common/eth_types, eth_keys, + ../enode, ../kademlia, ../discovery, ../options, ../rlpxcrypt + +packageTypes: type EthereumNode* = ref object networkId*: uint @@ -15,9 +21,9 @@ block: peerPool*: PeerPool Peer* = ref object - transp: StreamTransport + transport: StreamTransport dispatcher: Dispatcher - nextReqId: int + lastReqId*: int network*: EthereumNode secretsState: SecretState connectionState: ConnectionState @@ -27,7 +33,7 @@ block: awaitedMessages: seq[FutureBase] OutstandingRequest = object - reqId: int + id: int future: FutureBase timeoutAt: uint64 @@ -85,12 +91,14 @@ block: # protocolOffsets: seq[int] messages: seq[ptr MessageInfo] + activeProtocols: seq[ProtocolInfo] PeerObserver* = object onPeerConnected*: proc(p: Peer) onPeerDisconnected*: proc(p: Peer) - MessageHandler = proc(x: Peer, data: Rlp): Future[void] + MessageHandlerDecorator = proc(msgId: int, n: NimNode): NimNode + MessageHandler = proc(x: Peer, msgId: int, data: Rlp): Future[void] MessageContentPrinter = proc(msg: pointer): string RequestResolver = proc(msg: pointer, future: FutureBase) NextMsgResolver = proc(msgData: Rlp, future: FutureBase) @@ -98,7 +106,7 @@ block: NetworkStateInitializer = proc(network: EthereumNode): RootRef HandshakeStep = proc(peer: Peer): Future[void] DisconnectionHandler = proc(peer: Peer, - reason: DisconnectionReason): Future[void] + reason: DisconnectionReason): Future[void] {.gcsafe.} RlpxMessageKind* = enum rlpxNotification, @@ -133,9 +141,8 @@ block: MalformedMessageError* = object of Exception - UnexpectedDisconnectError* = object of Exception + PeerDisconnected* = object of Exception reason*: DisconnectionReason UselessPeerError* = object of Exception - diff --git a/eth_p2p/rlpx.nim b/eth_p2p/rlpx.nim index e547a16..8845a85 100644 --- a/eth_p2p/rlpx.nim +++ b/eth_p2p/rlpx.nim @@ -1,11 +1,15 @@ +import + macros, tables, algorithm, deques, hashes, options, typetraits, + chronicles, nimcrypto, asyncdispatch2, rlp, eth_common, eth_keys, + private/types, kademlia, auth, rlpxcrypt, enode + logScope: - topic = "rlpx" + topics = "rlpx" const - baseProtocolVersion = 4 - clientId = "nim-eth-p2p/0.2.0" - + devp2pVersion* = 4 defaultReqTimeout = 10000 + maxMsgSize = 1024 * 1024 var gProtocols: seq[ProtocolInfo] @@ -14,25 +18,37 @@ var # The variables above are immutable RTTI information. We need to tell # Nim to not consider them GcSafe violations: -template rlpxProtocols: auto = {.gcsafe.}: gProtocols +template rlpxProtocols*: auto = {.gcsafe.}: gProtocols template devp2pProtocolInfo: auto = {.gcsafe.}: devp2p -# Dispatcher -# +proc newFuture[T](location: var Future[T]) = + location = newFuture[T]() proc `$`*(p: Peer): string {.inline.} = $p.remote +proc disconnect*(peer: Peer, reason: DisconnectionReason) {.async.} + +template raisePeerDisconnected(msg: string, r: DisconnectionReason) = + var e = newException(PeerDisconnected, msg) + e.reason = r + raise e + +proc disconnectAndRaise(peer: Peer, + reason: DisconnectionReason, + msg: string) {.async.} = + let r = reason + await peer.disconnect(r) + raisePeerDisconnected(msg, r) + +# Dispatcher +# + proc hash(d: Dispatcher): int = hash(d.protocolOffsets) proc `==`(lhs, rhs: Dispatcher): bool = - lhs.protocolOffsets == rhs.protocolOffsets - -iterator activeProtocols(d: Dispatcher): ProtocolInfo = - for i in 0 ..< rlpxProtocols.len: - if d.protocolOffsets[i] != -1: - yield rlpxProtocols[i] + lhs.activeProtocols == rhs.activeProtocols proc describeProtocols(d: Dispatcher): string = result = "" @@ -41,7 +57,7 @@ proc describeProtocols(d: Dispatcher): string = for c in protocol.name: result.add(c) proc numProtocols(d: Dispatcher): int = - for _ in d.activeProtocols: inc result + d.activeProtocols.len proc getDispatcher(node: EthereumNode, otherPeerCapabilities: openarray[Capability]): Dispatcher = @@ -51,25 +67,19 @@ proc getDispatcher(node: EthereumNode, new(result) newSeq(result.protocolOffsets, rlpxProtocols.len) + result.protocolOffsets.fill -1 var nextUserMsgId = 0x10 - for i in 0 ..< rlpxProtocols.len: - let localProtocol = rlpxProtocols[i] - if not node.rlpxProtocols.contains(localProtocol): - result.protocolOffsets[i] = -1 - continue - + for localProtocol in node.rlpxProtocols: + let idx = localProtocol.index block findMatchingProtocol: for remoteCapability in otherPeerCapabilities: if localProtocol.name == remoteCapability.name and localProtocol.version == remoteCapability.version: - result.protocolOffsets[i] = nextUserMsgId + result.protocolOffsets[idx] = nextUserMsgId nextUserMsgId += localProtocol.messages.len break findMatchingProtocol - # the local protocol is not supported by the other peer - # indicate this by a -1 offset: - result.protocolOffsets[i] = -1 if result in gDispatchers: return gDispatchers[result] @@ -81,13 +91,40 @@ proc getDispatcher(node: EthereumNode, result.messages = newSeq[ptr MessageInfo](nextUserMsgId) devp2pProtocolInfo.messages.copyTo(result.messages, 0) - for i in 0 ..< rlpxProtocols.len: - if result.protocolOffsets[i] != -1: - rlpxProtocols[i].messages.copyTo(result.messages, - result.protocolOffsets[i]) + for localProtocol in node.rlpxProtocols: + let idx = localProtocol.index + if result.protocolOffsets[idx] != -1: + result.activeProtocols.add localProtocol + localProtocol.messages.copyTo(result.messages, + result.protocolOffsets[idx]) gDispatchers.incl result +proc getMsgName*(peer: Peer, msgId: int): string = + if not peer.dispatcher.isNil and + msgId < peer.dispatcher.messages.len: + return peer.dispatcher.messages[msgId].name + else: + return case msgId + of 0: "hello" + of 1: "disconnect" + of 2: "ping" + of 3: "pong" + else: $msgId + +proc getMsgMetadata*(peer: Peer, msgId: int): (ProtocolInfo, ptr MessageInfo) = + doAssert msgId >= 0 + + if msgId <= devp2p.messages[^1].id: + return (devp2p, addr devp2p.messages[msgId]) + + if msgId < peer.dispatcher.messages.len: + for i in 0 ..< rlpxProtocols.len: + let offset = peer.dispatcher.protocolOffsets[i] + if offset != -1 and + offset + rlpxProtocols[i].messages[^1].id >= msgId: + return (rlpxProtocols[i], peer.dispatcher.messages[msgId]) + # Protocol info objects # @@ -109,11 +146,17 @@ proc setEventHandlers(p: ProtocolInfo, p.handshake = handshake p.disconnectHandler = disconnectHandler -proc nameStr*(p: ProtocolInfo): string = +func asCapability*(p: ProtocolInfo): Capability = + result.name = p.name + result.version = p.version + +func nameStr*(p: ProtocolInfo): string = result = newStringOfCap(3) for c in p.name: result.add(c) -proc cmp*(lhs, rhs: ProtocolInfo): int {.inline.} = +# XXX: this used to be inline, but inline procs +# cannot be passed to closure params +proc cmp*(lhs, rhs: ProtocolInfo): int = for i in 0..2: if lhs.name[i] != rhs.name[i]: return int16(lhs.name[i]) - int16(rhs.name[i]) @@ -127,7 +170,7 @@ proc messagePrinter[MsgType](msg: pointer): string = proc nextMsgResolver[MsgType](msgData: Rlp, future: FutureBase) = var reader = msgData - Future[MsgType](future).complete reader.read(MsgType) + Future[MsgType](future).complete reader.readRecordType(MsgType, MsgType.rlpFieldsCount > 1) proc requestResolver[MsgType](msg: pointer, future: FutureBase) = var f = Future[Option[MsgType]](future) @@ -157,12 +200,12 @@ proc registerMsg(protocol: var ProtocolInfo, nextMsgResolver: NextMsgResolver) = if protocol.messages.len <= id: protocol.messages.setLen(id + 1) - protocol.messages[id] = MessageInfo(id: id, - name: name, - thunk: thunk, - printer: printer, - requestResolver: requestResolver, - nextMsgResolver: nextMsgResolver) + protocol.messages[id] = MessageInfo.init(id = id, + name = name, + thunk = thunk, + printer = printer, + requestResolver = requestResolver, + nextMsgResolver = nextMsgResolver) proc registerProtocol(protocol: ProtocolInfo) = # TODO: This can be done at compile-time in the future @@ -177,13 +220,24 @@ proc registerProtocol(protocol: ProtocolInfo) = # Message composition and encryption # +proc protocolOffset(peer: Peer, Protocol: type): int = + peer.dispatcher.protocolOffsets[Protocol.protocolInfo.index] + +proc perPeerMsgId(peer: Peer, proto: type, msgId: int): int {.inline.} = + result = msgId + if not peer.dispatcher.isNil: + result += peer.protocolOffset(proto) + +proc perPeerMsgId*(peer: Peer, MsgType: type): int {.inline.} = + peer.perPeerMsgId(MsgType.msgProtocol, MsgType.msgId) + proc writeMsgId(p: ProtocolInfo, msgId: int, peer: Peer, rlpOut: var RlpWriter) = let baseMsgId = peer.dispatcher.protocolOffsets[p.index] doAssert baseMsgId != -1 rlpOut.append(baseMsgId + msgId) -proc dispatchMsg(peer: Peer, msgId: int, msgData: var Rlp): Future[void] = +proc invokeThunk*(peer: Peer, msgId: int, msgData: var Rlp): Future[void] = template invalidIdError: untyped = raise newException(ValueError, "RLPx message with an invalid id " & $msgId & @@ -193,26 +247,40 @@ proc dispatchMsg(peer: Peer, msgId: int, msgData: var Rlp): Future[void] = let thunk = peer.dispatcher.messages[msgId].thunk if thunk == nil: invalidIdError() - return thunk(peer, msgData) + return thunk(peer, msgId, msgData) -proc sendMsg(p: Peer, data: Bytes) {.async.} = - # var rlp = rlpFromBytes(data) - # echo "sending: ", rlp.read(int) - # echo "payload: ", rlp.inspect - var cipherText = encryptMsg(data, p.secretsState) - discard await p.transp.write(cipherText) +proc linkSendFailureToReqFuture[S, R](sendFut: Future[S], resFut: Future[R]) = + sendFut.addCallback() do(arg: pointer): + if not sendFut.error.isNil: + resFut.fail(sendFut.error) + +proc sendMsg*(peer: Peer, data: Bytes) {.async.} = + trace "sending msg", peer, msg = getMsgName(peer, rlpFromBytes(data).read(int)) + + var cipherText = encryptMsg(data, peer.secretsState) + try: + discard await peer.transport.write(cipherText) + except: + await peer.disconnect(TcpError) + raise + +proc send*[Msg](peer: Peer, msg: Msg): Future[void] = + var rlpWriter = initRlpWriter() + rlpWriter.append perPeerMsgId(peer, Msg) + rlpWriter.appendRecordType(msg, Msg.rlpFieldsCount > 1) + peer.sendMsg rlpWriter.finish proc registerRequest*(peer: Peer, - timeout: int, - responseFuture: FutureBase, - responseMsgId: int): int = - result = peer.nextReqId - inc peer.nextReqId + timeout: int, + responseFuture: FutureBase, + responseMsgId: int): int = + inc peer.lastReqId + result = peer.lastReqId let timeoutAt = fastEpochTime() + uint64(timeout) - let req = OutstandingRequest(reqId: result, - future: responseFuture, - timeoutAt: timeoutAt) + let req = OutstandingRequest.init(id = result, + future = responseFuture, + timeoutAt = timeoutAt) peer.outstandingRequests[responseMsgId].addLast req assert(not peer.dispatcher.isNil) @@ -229,7 +297,7 @@ proc resolveResponseFuture(peer: Peer, msgId: int, msg: pointer, reqId: int) = remotePeer = peer.remote template resolve(future) = - peer.dispatcher.messages[msgId].requestResolver(msg, future) + (peer.dispatcher.messages[msgId].requestResolver)(msg, future) template outstandingReqs: auto = peer.outstandingRequests[msgId] @@ -267,7 +335,7 @@ proc resolveResponseFuture(peer: Peer, msgId: int, msg: pointer, reqId: int) = # correctly (because then, we'll be reusing the same reqIds for different # types of requests). Alternatively, we can assign a separate interval in # the `reqId` space for each type of response. - if reqId >= peer.nextReqId: + if reqId > peer.lastReqId: warn "RLPx response without a matching request" return @@ -287,7 +355,7 @@ proc resolveResponseFuture(peer: Peer, msgId: int, msg: pointer, reqId: int) = # more work to do: return - if req.reqId == reqId: + if req.id == reqId: resolve req.future # Here we'll remove the found request by swapping # it with the last one in the deque (if necessary): @@ -301,19 +369,22 @@ proc resolveResponseFuture(peer: Peer, msgId: int, msg: pointer, reqId: int) = debug "late or duplicate reply for a RLPx request" -template protocolOffset(peer: Peer, Protocol: type): int = - peer.dispatcher.protocolOffsets[Protocol.protocolInfo.index] - proc recvMsg*(peer: Peer): Future[tuple[msgId: int, msgData: Rlp]] {.async.} = ## This procs awaits the next complete RLPx message in the TCP stream var headerBytes: array[32, byte] - await peer.transp.readExactly(addr headerBytes[0], 32) + await peer.transport.readExactly(addr headerBytes[0], 32) var msgSize: int if decryptHeaderAndGetMsgSize(peer.secretsState, headerBytes, msgSize) != RlpxStatus.Success: - return (-1, zeroBytesRlp) + await peer.disconnectAndRaise(BreachOfProtocol, + "Cannot decrypt RLPx frame header") + + trace "waiting for message bytes", peer, msgSize + if msgSize > maxMsgSize: + await peer.disconnectAndRaise(BreachOfProtocol, + "RLPx message exceeds maximum size") let remainingBytes = encryptedLength(msgSize) - 32 # TODO: Migrate this to a thread-local seq @@ -323,7 +394,7 @@ proc recvMsg*(peer: Peer): Future[tuple[msgId: int, msgData: Rlp]] {.async.} = # also be useuful for chunked messages where part of the buffer may have # been processed and needs filling in var encryptedBytes = newSeq[byte](remainingBytes) - await peer.transp.readExactly(addr encryptedBytes[0], len(encryptedBytes)) + await peer.transport.readExactly(addr encryptedBytes[0], len(encryptedBytes)) let decryptedMaxLength = decryptedLength(msgSize) var @@ -332,27 +403,20 @@ proc recvMsg*(peer: Peer): Future[tuple[msgId: int, msgData: Rlp]] {.async.} = if decryptBody(peer.secretsState, encryptedBytes, msgSize, decryptedBytes, decryptedBytesCount) != RlpxStatus.Success: - return (-1, zeroBytesRlp) + await peer.disconnectAndRaise(BreachOfProtocol, + "Cannot decrypt RLPx frame body") decryptedBytes.setLen(decryptedBytesCount) var rlp = rlpFromBytes(decryptedBytes.toRange) - let msgId = rlp.read(int) - # if not peer.dispatcher.isNil: - # - # echo "Read msg: ", peer.dispatcher.messages[msgId].name - # else: - # echo "Read msg: ", msgId - return (msgId, rlp) -proc perPeerMsgId(peer: Peer, proto: type, msgId: int): int {.inline.} = - result = msgId - if not peer.dispatcher.isNil: - result += peer.protocolOffset(proto) + try: + let msgid = rlp.read(int) + return (msgId, rlp) + except RlpError: + await peer.disconnectAndRaise(BreachOfProtocol, + "Cannot read RLPx message id") -proc perPeerMsgId(peer: Peer, MsgType: type): int {.inline.} = - peer.perPeerMsgId(MsgType.msgProtocol, MsgType.msgId) - -proc checkedRlpRead(r: var Rlp, MsgType: type): auto {.inline.} = +proc checkedRlpRead(peer: Peer, r: var Rlp, MsgType: type): auto {.inline.} = let tmp = r when defined(release): return r.read(MsgType) @@ -362,26 +426,31 @@ proc checkedRlpRead(r: var Rlp, MsgType: type): auto {.inline.} = except: # echo "Failed rlp.read:", tmp.inspect error "Failed rlp.read", + peer = peer, msg = MsgType.name, exception = getCurrentExceptionMsg() # dataHex = r.rawData.toSeq().toHex() raise -proc waitSingleMsg*(peer: Peer, MsgType: type): Future[MsgType] {.async.} = +proc waitSingleMsg(peer: Peer, MsgType: type): Future[MsgType] {.async.} = let wantedId = peer.perPeerMsgId(MsgType) while true: var (nextMsgId, nextMsgData) = await peer.recvMsg() + if nextMsgId == wantedId: - return nextMsgData.checkedRlpRead(MsgType) + try: + return checkedRlpRead(peer, nextMsgData, MsgType) + except RlpError: + await peer.disconnectAndRaise(BreachOfProtocol, + "Invalid RLPx message body") elif nextMsgId == 1: # p2p.disconnect - let reason = nextMsgData.listElem(0).toInt(uint32).DisconnectionReason - let e = newException(UnexpectedDisconnectError, "Unexpected disconnect") - e.reason = reason - raise e + raisePeerDisconnected("Unexpected disconnect", + DisconnectionReason nextMsgData.listElem(0).toInt(uint32)) else: - warn "Dropped RLPX message", msg = peer.dispatcher.messages[nextMsgId].name + warn "Dropped RLPX message", + msg = peer.dispatcher.messages[nextMsgId].name proc nextMsg*(peer: Peer, MsgType: type): Future[MsgType] = ## This procs awaits a specific RLPx message. @@ -393,23 +462,31 @@ proc nextMsg*(peer: Peer, MsgType: type): Future[MsgType] = if not f.isNil: return Future[MsgType](f) - new result + newFuture result peer.awaitedMessages[wantedId] = result proc dispatchMessages*(peer: Peer) {.async.} = while true: var (msgId, msgData) = await peer.recvMsg() + trace "received msg ", peer, msg = getMsgName(peer, msgId) + # rpl = msgData.inspect - # echo "got msg(", msgId, "): ", msgData.inspect - if msgData.listLen != 0: - # TODO: this should be `enterList` - msgData = msgData.listElem(0) + if msgId == 1: # p2p.disconnect + await peer.transport.closeWait() + debug "remote peer disconnected", peer, + reason = msgData.listElem(0).toInt(uint32).DisconnectionReason + break - await peer.dispatchMsg(msgId, msgData) + try: + await peer.invokeThunk(msgId, msgData) + except RlpError: + error "endind dispatchMessages loop", peer, err = getCurrentExceptionMsg() + await peer.disconnect(BreachOfProtocol) + return if peer.awaitedMessages[msgId] != nil: let msgInfo = peer.dispatcher.messages[msgId] - msgInfo.nextMsgResolver(msgData, peer.awaitedMessages[msgId]) + (msgInfo.nextMsgResolver)(msgData, peer.awaitedMessages[msgId]) peer.awaitedMessages[msgId] = nil iterator typedParams(n: NimNode, skip = 0): (NimNode, NimNode) = @@ -429,7 +506,7 @@ proc chooseFieldType(n: NimNode): NimNode = result = n if n.kind == nnkBracketExpr and eqIdent(n[0], "openarray"): result = n.copyNimTree - result[0] = newIdentNode("seq") + result[0] = ident("seq") proc getState(peer: Peer, proto: ProtocolInfo): RootRef = peer.protocolStates[proto.index] @@ -442,29 +519,32 @@ template state*(peer: Peer, Protocol: type): untyped = ## Returns the state object of a particular protocol for a ## particular connection. bind getState - cast[ref Protocol.State](getState(peer, Protocol.protocolInfo)) + cast[Protocol.State](getState(peer, Protocol.protocolInfo)) -proc getNetworkState(peer: Peer, proto: ProtocolInfo): RootRef = - peer.network.protocolStates[proto.index] +proc getNetworkState(node: EthereumNode, proto: ProtocolInfo): RootRef = + node.protocolStates[proto.index] + +template protocolState*(node: EthereumNode, Protocol: type): untyped = + bind getNetworkState + cast[Protocol.NetworkState](getNetworkState(node, Protocol.protocolInfo)) template networkState*(connection: Peer, Protocol: type): untyped = ## Returns the network state object of a particular protocol for a ## particular connection. - bind getNetworkState - cast[ref Protocol.NetworkState](connection.getNetworkState(Protocol.protocolInfo)) + protocolState(connection.network, Protocol) -proc initProtocolState*[T](state: var T, x: Peer|EthereumNode) = discard +proc initProtocolState*[T](state: T, x: Peer|EthereumNode) = discard proc createPeerState[ProtocolState](peer: Peer): RootRef = var res = new ProtocolState mixin initProtocolState - initProtocolState(res[], peer) + initProtocolState(res, peer) return cast[RootRef](res) proc createNetworkState[NetworkState](network: EthereumNode): RootRef = var res = new NetworkState mixin initProtocolState - initProtocolState(res[], network) + initProtocolState(res, network) return cast[RootRef](res) proc popTimeoutParam(n: NimNode): NimNode = @@ -475,41 +555,66 @@ proc popTimeoutParam(n: NimNode): NimNode = result = lastParam n.params.del(n.params.len - 1) -proc linkSendFutureToResult[S, R](sendFut: Future[S], resFut: Future[R]) = - sendFut.addCallback() do(arg: pointer): - if not sendFut.error.isNil: - resFut.fail(sendFut.error) +proc verifyStateType(t: NimNode): NimNode = + result = t[1] + if result.kind == nnkSym and $result == "nil": + return nil + if result.kind != nnkBracketExpr or $result[0] != "ref": + macros.error($result & " must be a ref type") -macro rlpxProtocol*(protoIdentifier: untyped, - version: static[int], - body: untyped): untyped = +macro rlpxProtocolImpl(name: static[string], + version: static[uint], + body: untyped, + useRequestIds: static[bool] = true, + timeout: static[int] = defaultReqTimeout, + shortName: static[string] = "", + outgoingRequestDecorator: untyped = nil, + incomingRequestDecorator: untyped = nil, + incomingRequestThunkDecorator: untyped = nil, + incomingResponseDecorator: untyped = nil, + incomingResponseThunkDecorator: untyped = nil, + peerState = type(nil), + networkState = type(nil)): untyped = ## The macro used to defined RLPx sub-protocols. See README. var + # XXX: deal with a Nim bug causing the macro params to be + # zero when they are captured by a closure: + outgoingRequestDecorator = outgoingRequestDecorator + incomingRequestDecorator = incomingRequestDecorator + incomingRequestThunkDecorator = incomingRequestThunkDecorator + incomingResponseDecorator = incomingResponseDecorator + incomingResponseThunkDecorator = incomingResponseThunkDecorator + useRequestIds = useRequestIds + version = version + defaultTimeout = timeout + nextId = 0 + protoName = name + shortName = if shortName.len > 0: shortName else: protoName outTypes = newNimNode(nnkStmtList) outSendProcs = newNimNode(nnkStmtList) outRecvProcs = newNimNode(nnkStmtList) outProcRegistrations = newNimNode(nnkStmtList) - protoName = $protoIdentifier - protoNameIdent = newIdentNode(protoName) - resultIdent = newIdentNode "result" - protocol = genSym(nskVar, protoName & "Proto") - isSubprotocol = version > 0 - stateType: NimNode = nil - networkStateType: NimNode = nil + protoNameIdent = ident(protoName) + resultIdent = ident "result" + perProtocolMsgId = ident"perProtocolMsgId" + protocol = ident(protoName & "Protocol") + isSubprotocol = version > 0'u + peerState = verifyStateType peerState.getType + networkState = verifyStateType networkState.getType handshake = newNilLit() disconnectHandler = newNilLit() - useRequestIds = true Option = bindSym "Option" # XXX: Binding the int type causes instantiation failure for some reason # Int = bindSym "int" - Int = newIdentNode "int" + Int = ident "int" Peer = bindSym "Peer" append = bindSym "append" createNetworkState = bindSym "createNetworkState" createPeerState = bindSym "createPeerState" finish = bindSym "finish" initRlpWriter = bindSym "initRlpWriter" + enterList = bindSym "enterList" messagePrinter = bindSym "messagePrinter" newProtocol = bindSym "newProtocol" nextMsgResolver = bindSym "nextMsgResolver" @@ -525,15 +630,23 @@ macro rlpxProtocol*(protoIdentifier: untyped, getState = bindSym "getState" getNetworkState = bindSym "getNetworkState" perPeerMsgId = bindSym "perPeerMsgId" - linkSendFutureToResult = bindSym "linkSendFutureToResult" + linkSendFailureToReqFuture = bindSym "linkSendFailureToReqFuture" # By convention, all Ethereum protocol names must be abbreviated to 3 letters - assert protoName.len == 3 + assert shortName.len == 3 - proc augmentUserHandler(userHandlerProc: NimNode) = + template applyDecorator(p: NimNode, decorator: NimNode) = + if decorator.kind != nnkNilLit: p.addPragma decorator + + proc augmentUserHandler(userHandlerProc: NimNode, msgId = -1, msgKind = rlpxNotification) = ## Turns a regular proc definition into an async proc and adds ## the helpers for accessing the peer and network protocol states. - userHandlerProc.addPragma newIdentNode"async" + case msgKind + of rlpxRequest: userHandlerProc.applyDecorator incomingRequestDecorator + of rlpxResponse: userHandlerProc.applyDecorator incomingResponseDecorator + else: discard + + userHandlerProc.addPragma ident"async" # We allow the user handler to use `openarray` params, but we turn # those into sequences to make the `async` pragma happy. @@ -541,21 +654,25 @@ macro rlpxProtocol*(protoIdentifier: untyped, var param = userHandlerProc.params[i] param[^2] = chooseFieldType(param[^2]) + var userHandlerDefinitions = newStmtList() + + if msgId >= 0: + userHandlerDefinitions.add quote do: + const `perProtocolMsgId` = `msgId` + # Define local accessors for the peer and the network protocol states # inside each user message handler proc (e.g. peer.state.foo = bar) - if stateType != nil: - var localStateAccessor = quote: - template state(p: `Peer`): ref `stateType` = - cast[ref `stateType`](`getState`(p, `protocol`)) + if peerState != nil: + userHandlerDefinitions.add quote do: + template state(p: `Peer`): `peerState` = + cast[`peerState`](`getState`(p, `protocol`)) - userHandlerProc.body.insert 0, localStateAccessor + if networkState != nil: + userHandlerDefinitions.add quote do: + template networkState(p: `Peer`): `networkState` = + cast[`networkState`](`getNetworkState`(p.network, `protocol`)) - if networkStateType != nil: - var networkStateAccessor = quote: - template networkState(p: `Peer`): ref `networkStateType` = - cast[ref `networkStateType`](`getNetworkState`(p, `protocol`)) - - userHandlerProc.body.insert 0, networkStateAccessor + userHandlerProc.body.insert 0, userHandlerDefinitions proc liftEventHandler(doBlock: NimNode, handlerName: string): NimNode = ## Turns a "named" do block to a regular async proc @@ -571,25 +688,33 @@ macro rlpxProtocol*(protoIdentifier: untyped, msgKind = rlpxNotification, responseMsgId = -1, responseRecord: NimNode = nil): NimNode = + if n[0].kind == nnkPostfix: + macros.error("rlpxProcotol procs are public by default. " & + "Please remove the postfix `*`.", n) + let msgIdent = n.name msgName = $n.name + hasReqIds = useRequestIds and msgKind in {rlpxRequest, rlpxResponse} var paramCount = 0 + userPragmas = n.pragma # variables used in the sending procs - msgRecipient = genSym(nskParam, "msgRecipient") + msgRecipient = ident"msgRecipient" reqTimeout: NimNode - rlpWriter = genSym(nskVar, "writer") + rlpWriter = ident"writer" appendParams = newNimNode(nnkStmtList) - sentReqId = genSym(nskLet, "reqId") + reqId = ident"reqId" + perPeerMsgIdVar = ident"perPeerMsgId" # variables used in the receiving procs - msgSender = genSym(nskParam, "msgSender") - receivedRlp = genSym(nskVar, "rlp") - receivedMsg = genSym(nskVar, "msg") + msgSender = ident"msgSender" + receivedRlp = ident"rlp" + receivedMsg = ident"msg" readParams = newNimNode(nnkStmtList) + readParamsPrelude = newNimNode(nnkStmtList) callResolvedResponseFuture = newNimNode(nnkStmtList) # nodes to store the user-supplied message handling proc if present @@ -598,7 +723,7 @@ macro rlpxProtocol*(protoIdentifier: untyped, awaitUserHandler = newStmtList() # a record type associated with the message - msgRecord = genSym(nskType, msgName & "Obj") + msgRecord = newIdentNode(msgName & "Obj") msgRecordFields = newTree(nnkRecList) msgRecordBody = newTree(nnkObjectTy, newEmptyNode(), @@ -606,6 +731,10 @@ macro rlpxProtocol*(protoIdentifier: untyped, msgRecordFields) result = msgRecord + if hasReqIds: + # Messages using request Ids + readParamsPrelude.add quote do: + let `reqId` = `read`(`receivedRlp`, int) case msgKind of rlpxNotification: discard @@ -616,13 +745,12 @@ macro rlpxProtocol*(protoIdentifier: untyped, # When the timeout is not specified, we use a default one. reqTimeout = popTimeoutParam(n) if reqTimeout == nil: - reqTimeout = newTree(nnkIdentDefs, - genSym(nskParam, "timeout"), - Int, newLit(defaultReqTimeout)) + reqTimeout = newTree(nnkIdentDefs, + ident"timeout", + Int, newLit(defaultTimeout)) - let expectedMsgId = newCall(perPeerMsgId, msgRecipient, - protoNameIdent, - newLit(responseMsgId)) + let reqToResponseOffset = responseMsgId - msgId + let responseMsgId = quote do: `perPeerMsgIdVar` + `reqToResponseOffset` # Each request is registered so we can resolve it when the response # arrives. There are two types of protocols: LES-like protocols use @@ -632,43 +760,42 @@ macro rlpxProtocol*(protoIdentifier: untyped, let registerRequestCall = newCall(registerRequest, msgRecipient, reqTimeout[0], resultIdent, - expectedMsgId) - if useRequestIds: - inc paramCount + responseMsgId) + if hasReqIds: appendParams.add quote do: - new `resultIdent` - let `sentReqId` = `registerRequestCall` - `append`(`rlpWriter`, `sentReqId`) + newFuture `resultIdent` + let `reqId` = `registerRequestCall` + `append`(`rlpWriter`, `reqId`) else: appendParams.add quote do: - new `resultIdent` + newFuture `resultIdent` discard `registerRequestCall` + of rlpxResponse: - let expectedMsgId = newCall(perPeerMsgId, msgSender, msgRecord) - if useRequestIds: - var reqId = genSym(nskLet, "reqId") - - # Messages using request Ids - readParams.add quote do: - let `reqId` = `read`(`receivedRlp`, int) - - callResolvedResponseFuture.add quote do: - `resolveResponseFuture`(`msgSender`, `expectedMsgId`, addr(`receivedMsg`), `reqId`) - else: - callResolvedResponseFuture.add quote do: - `resolveResponseFuture`(`msgSender`, `expectedMsgId`, addr(`receivedMsg`), -1) + let reqIdVal = if hasReqIds: `reqId` else: newLit(-1) + callResolvedResponseFuture.add quote do: + `resolveResponseFuture`(`msgSender`, + `perPeerMsgId`(`msgSender`, `msgRecord`), + addr(`receivedMsg`), + `reqIdVal`) + if hasReqIds: + appendParams.add newCall(append, rlpWriter, reqId) if n.body.kind != nnkEmpty: # implement the receiving thunk proc that deserialzed the # message parameters and calls the user proc: userHandlerProc = n.copyNimTree userHandlerProc.name = genSym(nskProc, msgName) - augmentUserHandler userHandlerProc + augmentUserHandler userHandlerProc, msgId, msgKind # This is the call to the user supplied handled. Here we add only the # initial peer param, while the rest of the params will be added later. userHandlerCall = newCall(userHandlerProc.name, msgSender) + if hasReqIds: + userHandlerProc.params.insert(2, newIdentDefs(reqId, ident"int")) + userHandlerCall.add reqId + # When there is a user handler, it must be awaited in the thunk proc. # Above, by default `awaitUserHandler` is set to a no-op statement list. awaitUserHandler = newCall("await", userHandlerCall) @@ -680,35 +807,51 @@ macro rlpxProtocol*(protoIdentifier: untyped, # This is a fragment of the sending proc that # serializes each of the passed parameters: - appendParams.add quote do: - `append`(`rlpWriter`, `param`) + appendParams.add newCall(append, rlpWriter, param) # Each message has a corresponding record type. # Here, we create its fields one by one: msgRecordFields.add newTree(nnkIdentDefs, - param, chooseFieldType(paramType), newEmptyNode()) + newTree(nnkPostfix, ident("*"), param), # The fields are public + chooseFieldType(paramType), # some types such as openarray + # are automatically remapped + newEmptyNode()) # The received RLP data is deserialized to a local variable of # the message-specific type. This is done field by field here: let msgNameLit = newLit(msgName) readParams.add quote do: - `receivedMsg`.`param` = `checkedRlpRead`(`receivedRlp`, `paramType`) + `receivedMsg`.`param` = `checkedRlpRead`(`msgSender`, `receivedRlp`, `paramType`) # If there is user message handler, we'll place a call to it by # unpacking the fields of the received message: if userHandlerCall != nil: userHandlerCall.add newDotExpr(receivedMsg, param) - let thunkName = newIdentNode(msgName & "_thunk") + if paramCount > 1: + readParamsPrelude.add newCall(enterList, receivedRlp) - outRecvProcs.add quote do: - proc `thunkName`(`msgSender`: `Peer`, data: Rlp) {.async.} = + let thunkName = ident(msgName & "_thunk") + var thunkProc = quote do: + proc `thunkName`(`msgSender`: `Peer`, _: int, data: Rlp) = var `receivedRlp` = data var `receivedMsg` {.noinit.}: `msgRecord` + `readParamsPrelude` `readParams` `awaitUserHandler` `callResolvedResponseFuture` + for p in userPragmas: thunkProc.addPragma p + + case msgKind + of rlpxRequest: thunkProc.applyDecorator incomingRequestThunkDecorator + of rlpxResponse: thunkProc.applyDecorator incomingResponseThunkDecorator + else: discard + + thunkProc.addPragma ident"async" + + outRecvProcs.add thunkProc + outTypes.add quote do: # This is a type featuring a single field for each message param: type `msgRecord`* = `msgRecordBody` @@ -727,41 +870,68 @@ macro rlpxProtocol*(protoIdentifier: untyped, msgSendProc.params[1][0] = msgRecipient # Add a timeout parameter for all request procs - if msgKind == rlpxRequest: msgSendProc.params.add reqTimeout + case msgKind + of rlpxRequest: + msgSendProc.params.add reqTimeout + of rlpxResponse: + if useRequestIds: + msgSendProc.params.insert 2, newIdentDefs(reqId, ident"int") + else: discard # We change the return type of the sending proc to a Future. # If this is a request proc, the future will return the response record. - let rt = if msgKind != rlpxRequest: newIdentNode"void" + let rt = if msgKind != rlpxRequest: ident"void" else: newTree(nnkBracketExpr, Option, responseRecord) - msgSendProc.params[0] = newTree(nnkBracketExpr, newIdentNode("Future"), rt) + msgSendProc.params[0] = newTree(nnkBracketExpr, ident("Future"), rt) - let writeMsgId = if isSubprotocol: - quote: `writeMsgId`(`protocol`, `msgId`, `msgRecipient`, `rlpWriter`) - else: - quote: `append`(`rlpWriter`, `msgId`) + let msgBytes = ident"msgBytes" - var sendCall = newCall(sendMsg, msgRecipient, newCall(finish, rlpWriter)) + let finalizeRequest = quote do: + let `msgBytes` = `finish`(`rlpWriter`) + + var sendCall = newCall(sendMsg, msgRecipient, msgBytes) let senderEpilogue = if msgKind == rlpxRequest: # In RLPx requests, the returned future was allocated here and passed # to `registerRequest`. It's already assigned to the result variable # of the proc, so we just wait for the sending operation to complete # and we return in a normal way. (the waiting is done, so we can catch # any possible errors). - quote: `linkSendFutureToResult`(`sendCall`, `resultIdent`) + quote: `linkSendFailureToReqFuture`(`sendCall`, `resultIdent`) else: # In normal RLPx messages, we are returning the future returned by the # `sendMsg` call. quote: return `sendCall` + let `perPeerMsgIdValue` = if isSubprotocol: + newCall(perPeerMsgId, msgRecipient, protoNameIdent, perProtocolMsgId) + else: + perProtocolMsgId + + if paramCount > 1: + # In case there are more than 1 parameter, + # the params must be wrapped in a list: + appendParams = newStmtList( + newCall(startList, rlpWriter, newLit(paramCount)), + appendParams) + + # Make the send proc public + msgSendProc.name = newTree(nnkPostfix, ident("*"), msgSendProc.name) + # let paramCountNode = newLit(paramCount) msgSendProc.body = quote do: var `rlpWriter` = `initRlpWriter`() - `writeMsgId` - if `paramCount` > 1: - `startList`(`rlpWriter`, `paramCount`) + let `perProtocolMsgId` = `msgId` + let `perPeerMsgIdVar` = `perPeerMsgIdValue` + + `append`(`rlpWriter`, `perPeerMsgIdVar`) `appendParams` + + `finalizeRequest` `senderEpilogue` + if msgKind == rlpxRequest: + msgSendProc.applyDecorator outgoingRequestDecorator + outSendProcs.add msgSendProc outProcRegistrations.add( @@ -779,6 +949,14 @@ macro rlpxProtocol*(protoIdentifier: untyped, # (e.g. p2p) type `protoNameIdent`* = object + if peerState != nil: + outTypes.add quote do: + template State*(P: type `protoNameIdent`): type = `peerState` + + if networkState != nil: + outTypes.add quote do: + template NetworkState*(P: type `protoNameIdent`): type = `networkState` + for n in body: case n.kind of {nnkCall, nnkCommand}: @@ -824,33 +1002,6 @@ macro rlpxProtocol*(protoIdentifier: untyped, disconnectHandler = liftEventHandler(n[1], "PeerDisconnect") else: macros.error(repr(n) & " is not a recognized call in RLPx protocol definitions", n) - - of nnkAsgn: - if eqIdent(n[0], "useRequestIds"): - useRequestIds = $n[1] == "true" - else: - macros.error(repr(n[0]) & " is not a recognized protocol option") - - of nnkTypeSection: - outTypes.add n - for typ in n: - if eqIdent(typ[0], "State"): - stateType = genSym(nskType, protoName & "State") - typ[0] = stateType - outTypes.add quote do: - template State*(P: type `protoNameIdent`): type = - `stateType` - - elif eqIdent(typ[0], "NetworkState"): - networkStateType = genSym(nskType, protoName & "NetworkState") - typ[0] = networkStateType - outTypes.add quote do: - template NetworkState*(P: type `protoNameIdent`): type = - `networkStateType` - - else: - macros.error("The only type names allowed within a RLPx protocol definition are 'State' and 'NetworkState'") - of nnkProcDef: discard addMsgHandler(nextId, n) inc nextId @@ -861,17 +1012,17 @@ macro rlpxProtocol*(protoIdentifier: untyped, else: macros.error("illegal syntax in a RLPx protocol definition", n) - let peerInit = if stateType == nil: newNilLit() - else: newTree(nnkBracketExpr, createPeerState, stateType) + let peerInit = if peerState == nil: newNilLit() + else: newTree(nnkBracketExpr, createPeerState, peerState) - let netInit = if networkStateType == nil: newNilLit() - else: newTree(nnkBracketExpr, createNetworkState, stateType) + let netInit = if networkState == nil: newNilLit() + else: newTree(nnkBracketExpr, createNetworkState, networkState) result = newNimNode(nnkStmtList) result.add outTypes result.add quote do: # One global variable per protocol holds the protocol run-time data - var `protocol` = `newProtocol`(`protoName`, `version`, `peerInit`, `netInit`) + var `protocol` = `newProtocol`(`shortName`, `version`, `peerInit`, `netInit`) # The protocol run-time data is available as a pseudo-field # (e.g. `p2p.protocolInfo`) @@ -885,7 +1036,18 @@ macro rlpxProtocol*(protoIdentifier: untyped, when isMainModule: echo repr(result) # echo repr(result) -rlpxProtocol p2p, 0: +macro rlpxProtocol*(protocolOptions: untyped, body: untyped): untyped = + let protoName = $(protocolOptions[0]) + result = protocolOptions + result[0] = bindSym"rlpxProtocolImpl" + result.add(newTree(nnkExprEqExpr, + ident("name"), + newLit(protoName))) + result.add(newTree(nnkExprEqExpr, + ident("body"), + body)) + +rlpxProtocol p2p(version = 0): proc hello(peer: Peer, version: uint, clientId: string, @@ -901,31 +1063,45 @@ rlpxProtocol p2p, 0: proc pong(peer: Peer) = discard +proc removePeer(network: EthereumNode, peer: Peer) = + if network.peerPool != nil: + network.peerPool.connectedNodes.del(peer.remote) + + for observer in network.peerPool.observers.values: + if not observer.onPeerDisconnected.isNil: + observer.onPeerDisconnected(peer) + +proc callDisconnectHandlers(peer: Peer, reason: DisconnectionReason): Future[void] = + var futures = newSeqOfCap[Future[void]](rlpxProtocols.len) + + for protocol in peer.dispatcher.activeProtocols: + if protocol.disconnectHandler != nil: + futures.add((protocol.disconnectHandler)(peer, reason)) + + return all(futures) + proc disconnect*(peer: Peer, reason: DisconnectionReason) {.async.} = if peer.connectionState notin {Disconnecting, Disconnected}: peer.connectionState = Disconnecting - await peer.sendDisconnectMsg(reason) - peer.connectionState = Disconnected - # TODO: Any other clean up required? - -template `^`(arr): auto = - # passes a stack array with a matching `arrLen` - # variable as an open array - arr.toOpenArray(0, `arr Len` - 1) + try: + # TODO: investigate the failure here + if not peer.transport.closed and false: + await peer.sendDisconnectMsg(reason) + finally: + if not peer.dispatcher.isNil: + await callDisconnectHandlers(peer, reason) + peer.connectionState = Disconnected + removePeer(peer.network, peer) proc validatePubKeyInHello(msg: p2p.hello, pubKey: PublicKey): bool = var pk: PublicKey recoverPublicKey(msg.nodeId, pk) == EthKeysStatus.Success and pk == pubKey -proc check(status: AuthStatus) = - if status != AuthStatus.Success: - raise newException(Exception, "Error: " & $status) - proc performSubProtocolHandshakes(peer: Peer) {.async.} = var subProtocolsHandshakes = newSeqOfCap[Future[void]](rlpxProtocols.len) for protocol in peer.dispatcher.activeProtocols: if protocol.handshake != nil: - subProtocolsHandshakes.add protocol.handshake(peer) + subProtocolsHandshakes.add((protocol.handshake)(peer)) await all(subProtocolsHandshakes) peer.connectionState = Connected @@ -935,9 +1111,8 @@ proc checkUselessPeer(peer: Peer) {.inline.} = # XXX: Send disconnect + UselessPeer raise newException(UselessPeerError, "Useless peer") -proc postHelloSteps(peer: Peer, h: p2p.hello): Future[void] = - peer.dispatcher = getDispatcher(peer.network, h.capabilities) - +proc initPeerState*(peer: Peer, capabilities: openarray[Capability]) = + peer.dispatcher = getDispatcher(peer.network, capabilities) checkUselessPeer(peer) # The dispatcher has determined our message ID sequence. @@ -952,7 +1127,7 @@ proc postHelloSteps(peer: Peer, h: p2p.hello): Future[void] = # of the potentially concurrent calls to `nextMsg`. peer.awaitedMessages.newSeq(peer.dispatcher.messages.len) - peer.nextReqId = 1 + peer.lastReqId = 0 # Initialize all the active protocol states newSeq(peer.protocolStates, rlpxProtocols.len) @@ -961,8 +1136,26 @@ proc postHelloSteps(peer: Peer, h: p2p.hello): Future[void] = if peerStateInit != nil: peer.protocolStates[protocol.index] = peerStateInit(peer) +proc postHelloSteps(peer: Peer, h: p2p.hello): Future[void] = + initPeerState(peer, h.capabilities) + + var messageProcessingLoop = peer.dispatchMessages() + + messageProcessingLoop.callback = proc(p: pointer) {.gcsafe.} = + if messageProcessingLoop.failed: + asyncCheck peer.disconnect(ClientQuitting) + return performSubProtocolHandshakes(peer) +template `^`(arr): auto = + # passes a stack array with a matching `arrLen` + # variable as an open array + arr.toOpenArray(0, `arr Len` - 1) + +proc check(status: AuthStatus) = + if status != AuthStatus.Success: + raise newException(Exception, "Error: " & $status) + proc initSecretState(hs: var Handshake, authMsg, ackMsg: openarray[byte], p: Peer) = var secrets: ConnectionSecret @@ -978,7 +1171,7 @@ proc rlpxConnect*(node: EthereumNode, remote: Node): Future[Peer] {.async.} = let ta = initTAddress(remote.node.address.ip, remote.node.address.tcpPort) var ok = false try: - result.transp = await connect(ta) + result.transport = await connect(ta) var handshake = newHandshake({Initiator, EIP8}) handshake.host = node.keys @@ -986,19 +1179,19 @@ proc rlpxConnect*(node: EthereumNode, remote: Node): Future[Peer] {.async.} = var authMsg: array[AuthMessageMaxEIP8, byte] var authMsgLen = 0 check authMessage(handshake, remote.node.pubkey, authMsg, authMsgLen) - var res = result.transp.write(addr authMsg[0], authMsgLen) + var res = result.transport.write(addr authMsg[0], authMsgLen) let initialSize = handshake.expectedLength var ackMsg = newSeqOfCap[byte](1024) ackMsg.setLen(initialSize) - await result.transp.readExactly(addr ackMsg[0], len(ackMsg)) + await result.transport.readExactly(addr ackMsg[0], len(ackMsg)) var ret = handshake.decodeAckMessage(ackMsg) if ret == AuthStatus.IncompleteError: ackMsg.setLen(handshake.expectedLength) - await result.transp.readExactly(addr ackMsg[initialSize], - len(ackMsg) - initialSize) + await result.transport.readExactly(addr ackMsg[initialSize], + len(ackMsg) - initialSize) ret = handshake.decodeAckMessage(ackMsg) check ret @@ -1007,7 +1200,7 @@ proc rlpxConnect*(node: EthereumNode, remote: Node): Future[Peer] {.async.} = # if handshake.remoteHPubkey != remote.node.pubKey: # raise newException(Exception, "Remote pubkey is wrong") - asyncCheck result.hello(baseProtocolVersion, + asyncCheck result.hello(devp2pVersion, node.clientId, node.rlpxCapabilities, uint(node.address.tcpPort), @@ -1020,13 +1213,13 @@ proc rlpxConnect*(node: EthereumNode, remote: Node): Future[Peer] {.async.} = await postHelloSteps(result, response) ok = true - except UnexpectedDisconnectError as e: + except PeerDisconnected as e: if e.reason != TooManyPeers: debug "Unexpected disconnect during rlpxConnect", reason = e.reason except TransportIncompleteError: debug "Connection dropped in rlpxConnect", remote except UselessPeerError: - debug "Useless peer" + debug "Useless peer ", peer = remote except RlpTypeMismatch: # Some peers report capabilities with names longer than 3 chars. We ignore # those for now. Maybe we should allow this though. @@ -1037,14 +1230,14 @@ proc rlpxConnect*(node: EthereumNode, remote: Node): Future[Peer] {.async.} = err = getCurrentExceptionMsg() if not ok: - if not isNil(result.transp): - result.transp.close() + if not isNil(result.transport): + result.transport.close() result = nil proc rlpxAccept*(node: EthereumNode, - transp: StreamTransport): Future[Peer] {.async.} = + transport: StreamTransport): Future[Peer] {.async.} = new result - result.transp = transp + result.transport = transport result.network = node var handshake = newHandshake({Responder}) @@ -1054,32 +1247,32 @@ proc rlpxAccept*(node: EthereumNode, let initialSize = handshake.expectedLength var authMsg = newSeqOfCap[byte](1024) authMsg.setLen(initialSize) - await transp.readExactly(addr authMsg[0], len(authMsg)) + await transport.readExactly(addr authMsg[0], len(authMsg)) var ret = handshake.decodeAuthMessage(authMsg) if ret == AuthStatus.IncompleteError: # Eip8 auth message is likely authMsg.setLen(handshake.expectedLength) - await transp.readExactly(addr authMsg[initialSize], - len(authMsg) - initialSize) + await transport.readExactly(addr authMsg[initialSize], + len(authMsg) - initialSize) ret = handshake.decodeAuthMessage(authMsg) check ret var ackMsg: array[AckMessageMaxEIP8, byte] var ackMsgLen: int check handshake.ackMessage(ackMsg, ackMsgLen) - var res = transp.write(addr ackMsg[0], ackMsgLen) + var res = transport.write(addr ackMsg[0], ackMsgLen) initSecretState(handshake, authMsg, ^ackMsg, result) + let listenPort = transport.localAddress().port + await result.hello(devp2pVersion, node.clientId, + node.rlpxCapabilities, listenPort.uint, + node.keys.pubkey.getRaw()) + var response = await result.waitSingleMsg(p2p.hello) if not validatePubKeyInHello(response, handshake.remoteHPubkey): warn "A Remote nodeId is not its public key" # XXX: Do we care? - let listenPort = transp.localAddress().port - await result.hello(baseProtocolVersion, node.clientId, - node.rlpxCapabilities, listenPort.uint, - node.keys.pubkey.getRaw()) - - let remote = transp.remoteAddress() + let remote = transport.remoteAddress() let address = Address(ip: remote.address, tcpPort: remote.port, udpPort: remote.port) result.remote = newNode(initEnode(handshake.remoteHPubkey, address)) @@ -1089,10 +1282,11 @@ proc rlpxAccept*(node: EthereumNode, error "Exception in rlpxAccept", err = getCurrentExceptionMsg(), stackTrace = getCurrentException().getStackTrace() - transp.close() + transport.close() result = nil when isMainModule: + when false: # The assignments below can be used to investigate if the RLPx procs # are considered GcSafe. The short answer is that they aren't, because @@ -1103,11 +1297,11 @@ when isMainModule: GcSafeRecvMsg = proc (peer: Peer): Future[tuple[msgId: int, msgData: Rlp]] {.gcsafe.} - GcSafeAccept = proc (transp: StreamTransport, myKeys: KeyPair): + GcSafeAccept = proc (transport: StreamTransport, myKeys: KeyPair): Future[Peer] {.gcsafe.} var - dispatchMsgPtr = dispatchMsg + dispatchMsgPtr = invokeThunk recvMsgPtr: GcSafeRecvMsg = recvMsg acceptPtr: GcSafeAccept = rlpxAccept diff --git a/eth_p2p/rlpx_protocols/eth_protocol.nim b/eth_p2p/rlpx_protocols/eth_protocol.nim index 0798f06..486f22d 100644 --- a/eth_p2p/rlpx_protocols/eth_protocol.nim +++ b/eth_p2p/rlpx_protocols/eth_protocol.nim @@ -12,9 +12,8 @@ ## https://github.com/ethereum/wiki/wiki/Ethereum-Wire-Protocol import - random, algorithm, hashes, - asyncdispatch2, rlp, stint, eth_common, chronicles, - ../../eth_p2p + asyncdispatch2, stint, chronicles, rlp, eth_common/eth_types, + ../rlpx, ../private/types, ../blockchain_utils, ../../eth_p2p type NewBlockHashesAnnounce* = object @@ -25,26 +24,21 @@ type header: BlockHeader body {.rlpInline.}: BlockBody - NetworkState = object - syncing: bool - - PeerState = object - initialized: bool - bestBlockHash: KeccakHash - bestDifficulty: DifficultyInt + PeerState = ref object + initialized*: bool + bestBlockHash*: KeccakHash + bestDifficulty*: DifficultyInt const - maxStateFetch = 384 - maxBodiesFetch = 128 - maxReceiptsFetch = 256 - maxHeadersFetch = 192 - protocolVersion = 63 - minPeersToStartSync = 2 # Wait for consensus of at least this number of peers before syncing + maxStateFetch* = 384 + maxBodiesFetch* = 128 + maxReceiptsFetch* = 256 + maxHeadersFetch* = 192 + protocolVersion* = 63 -rlpxProtocol eth, protocolVersion: - useRequestIds = false - - type State = PeerState +rlpxProtocol eth(version = protocolVersion, + peerState = PeerState, + useRequestIds = false): onPeerConnected do (peer: Peer): let @@ -58,9 +52,9 @@ rlpxProtocol eth, protocolVersion: bestBlock.blockHash, chain.genesisHash) - let m = await peer.waitSingleMsg(eth.status) + let m = await peer.nextMsg(eth.status) if m.networkId == network.networkId and m.genesisHash == chain.genesisHash: - debug "Suitable peer", peer + debug "suitable peer", peer else: raise newException(UselessPeerError, "Eth handshake params mismatch") peer.state.initialized = true @@ -72,16 +66,7 @@ rlpxProtocol eth, protocolVersion: networkId: uint, totalDifficulty: DifficultyInt, bestHash: KeccakHash, - genesisHash: KeccakHash) = - # verify that the peer is on the same chain: - if peer.network.networkId != networkId or - peer.network.chain.genesisHash != genesisHash: - # TODO: Is there a more specific reason here? - await peer.disconnect(SubprotocolReason) - return - - peer.state.bestBlockHash = bestHash - peer.state.bestDifficulty = totalDifficulty + genesisHash: KeccakHash) proc newBlockHashes(peer: Peer, hashes: openarray[NewBlockHashesAnnounce]) = discard @@ -95,19 +80,7 @@ rlpxProtocol eth, protocolVersion: await peer.disconnect(BreachOfProtocol) return - var headers = newSeqOfCap[BlockHeader](request.maxResults) - let chain = peer.network.chain - var foundBlock: BlockHeader - - if chain.getBlockHeader(request.startBlock, foundBlock): - headers.add foundBlock - - while uint64(headers.len) < request.maxResults: - if not chain.getSuccessorHeader(foundBlock, foundBlock): - break - headers.add foundBlock - - await peer.blockHeaders(headers) + await peer.blockHeaders(peer.network.chain.getBlockHeaders(request)) proc blockHeaders(p: Peer, headers: openarray[BlockHeader]) @@ -117,18 +90,7 @@ rlpxProtocol eth, protocolVersion: await peer.disconnect(BreachOfProtocol) return - var chain = peer.network.chain - - var blockBodies = newSeqOfCap[BlockBody](hashes.len) - for hash in hashes: - let blockBody = chain.getBlockBody(hash) - if not blockBody.isNil: - # TODO: should there be an else clause here. - # Is the peer responsible of figuring out that - # some blocks were not found? - blockBodies.add deref(blockBody) - - await peer.blockBodies(blockBodies) + await peer.blockBodies(peer.network.chain.getBlockBodies(hashes)) proc blockBodies(peer: Peer, blocks: openarray[BlockBody]) @@ -139,18 +101,13 @@ rlpxProtocol eth, protocolVersion: requestResponse: proc getNodeData(peer: Peer, hashes: openarray[KeccakHash]) = - await peer.nodeData([]) + await peer.nodeData(peer.network.chain.getStorageNodes(hashes)) - proc nodeData(peer: Peer, data: openarray[Blob]) = - discard + proc nodeData(peer: Peer, data: openarray[Blob]) requestResponse: proc getReceipts(peer: Peer, hashes: openarray[KeccakHash]) = - await peer.receipts([]) - - proc receipts(peer: Peer, receipts: openarray[Receipt]) = - discard - -proc hash*(p: Peer): Hash {.inline.} = hash(cast[pointer](p)) + await peer.receipts(peer.network.chain.getReceipts(hashes)) + proc receipts(peer: Peer, receipts: openarray[Receipt]) diff --git a/eth_p2p/rlpx_protocols/les/flow_control.nim b/eth_p2p/rlpx_protocols/les/flow_control.nim new file mode 100644 index 0000000..f8196d6 --- /dev/null +++ b/eth_p2p/rlpx_protocols/les/flow_control.nim @@ -0,0 +1,501 @@ +import + tables, sets, + chronicles, asyncdispatch2, rlp, eth_common/eth_types, + ../../rlpx, ../../private/types, private/les_types + +const + maxSamples = 100000 + rechargingScale = 1000000 + + lesStatsKey = "les.flow_control.stats" + lesStatsVer = 0 + +logScope: + topics = "les flow_control" + +# TODO: move this somewhere +proc pop[A, B](t: var Table[A, B], key: A): B = + result = t[key] + t.del(key) + +when LesTime is SomeInteger: + template `/`(lhs, rhs: LesTime): LesTime = + lhs div rhs + +when defined(testing): + var lesTime* = LesTime(0) + template now(): LesTime = lesTime + template advanceTime(t) = lesTime += LesTime(t) + +else: + import times + let startTime = epochTime() + + proc now(): LesTime = + return LesTime((times.epochTime() - startTime) * 1000.0) + +proc addSample(ra: var StatsRunningAverage; x, y: float64) = + if ra.count >= maxSamples: + let decay = float64(ra.count + 1 - maxSamples) / maxSamples + template applyDecay(x) = x -= x * decay + + applyDecay ra.sumX + applyDecay ra.sumY + applyDecay ra.sumXX + applyDecay ra.sumXY + ra.count = maxSamples - 1 + + inc ra.count + ra.sumX += x + ra.sumY += y + ra.sumXX += x * x + ra.sumXY += x * y + +proc calc(ra: StatsRunningAverage): tuple[m, b: float] = + if ra.count == 0: + return + + let count = float64(ra.count) + let d = count * ra.sumXX - ra.sumX * ra.sumX + if d < 0.001: + return (m: ra.sumY / count, b: 0.0) + + result.m = (count * ra.sumXY - ra.sumX * ra.sumY) / d + result.b = (ra.sumY / count) - (result.m * ra.sumX / count) + +proc currentRequestsCosts*(network: LesNetwork, + les: ProtocolInfo): seq[ReqCostInfo] = + # Make sure the message costs are already initialized + doAssert network.messageStats.len > les.messages[^1].id, + "Have you called `initFlowControl`" + + for msg in les.messages: + var (m, b) = network.messageStats[msg.id].calc() + if m < 0: + b += m + m = 0 + + if b < 0: + b = 0 + + result.add ReqCostInfo.init(msgId = msg.id, + baseCost = ReqCostInt(b * 2), + reqCost = ReqCostInt(m * 2)) + +proc persistMessageStats*(db: AbstractChainDB, + network: LesNetwork) = + doAssert db != nil + # XXX: Because of the package_visible_types template magic, Nim complains + # when we pass the messageStats expression directly to `encodeList` + let stats = network.messageStats + db.setSetting(lesStatsKey, rlp.encodeList(lesStatsVer, stats)) + +proc loadMessageStats*(network: LesNetwork, + les: ProtocolInfo, + db: AbstractChainDb): bool = + block readFromDB: + if db == nil: + break readFromDB + + var stats = db.getSetting(lesStatsKey) + if stats.len == 0: + notice "LES stats not present in the database" + break readFromDB + + try: + var statsRlp = rlpFromBytes(stats.toRange) + statsRlp.enterList + + let version = statsRlp.read(int) + if version != lesStatsVer: + notice "Found outdated LES stats record" + break readFromDB + + statsRlp >> network.messageStats + if network.messageStats.len <= les.messages[^1].id: + notice "Found incomplete LES stats record" + break readFromDB + + return true + + except RlpError: + error "Error while loading LES message stats", + err = getCurrentExceptionMsg() + + newSeq(network.messageStats, les.messages[^1].id + 1) + return false + +proc update(s: var FlowControlState, t: LesTime) = + let dt = max(t - s.lastUpdate, LesTime(0)) + + s.bufValue = min( + s.bufValue + s.minRecharge * dt, + s.bufLimit) + + s.lastUpdate = t + +proc init(s: var FlowControlState, + bufLimit: BufValueInt, minRecharge: int, t: LesTime) = + s.bufValue = bufLimit + s.bufLimit = bufLimit + s.minRecharge = minRecharge + s.lastUpdate = t + +func canMakeRequest(s: FlowControlState, + maxCost: ReqCostInt): (LesTime, float64) = + ## Returns the required waiting time before sending a request and + ## the estimated buffer level afterwards (as a fraction of the limit) + const safetyMargin = 50 + + var maxCost = min( + maxCost + safetyMargin * s.minRecharge, + s.bufLimit) + + if s.bufValue >= maxCost: + result[1] = float64(s.bufValue - maxCost) / float64(s.bufLimit) + else: + result[0] = (maxCost - s.bufValue) / s.minRecharge + +func canServeRequest(srv: LesNetwork): bool = + result = srv.reqCount < srv.maxReqCount and + srv.reqCostSum < srv.maxReqCostSum + +proc rechargeReqCost(peer: LesPeer, t: LesTime) = + let dt = t - peer.lastRechargeTime + peer.reqCostVal += peer.reqCostGradient * dt / rechargingScale + peer.lastRechargeTime = t + if peer.isRecharging and t >= peer.rechargingEndsAt: + peer.isRecharging = false + peer.reqCostGradient = 0 + peer.reqCostVal = 0 + +proc updateRechargingParams(peer: LesPeer, network: LesNetwork) = + peer.reqCostGradient = 0 + if peer.reqCount > 0: + peer.reqCostGradient = rechargingScale / network.reqCount + + if peer.isRecharging: + peer.reqCostGradient = (network.rechargingRate * peer.rechargingPower / + network.totalRechargingPower ) + + peer.rechargingEndsAt = peer.lastRechargeTime + + LesTime(peer.reqCostVal * rechargingScale / + -peer.reqCostGradient ) + +proc trackRequests(network: LesNetwork, peer: LesPeer, reqCountChange: int) = + peer.reqCount += reqCountChange + network.reqCount += reqCountChange + + doAssert peer.reqCount >= 0 and network.reqCount >= 0 + + if peer.reqCount == 0: + # All requests have been finished. Start recharging. + peer.isRecharging = true + network.totalRechargingPower += peer.rechargingPower + elif peer.reqCount == reqCountChange and peer.isRecharging: + # `peer.reqCount` must have been 0 for the condition above to hold. + # This is a transition from recharging to serving state. + peer.isRecharging = false + network.totalRechargingPower -= peer.rechargingPower + peer.startReqCostVal = peer.reqCostVal + + updateRechargingParams peer, network + +proc updateFlowControl(network: LesNetwork, t: LesTime) = + while true: + var firstTime = t + for peer in network.peers: + # TODO: perhaps use a bin heap here + if peer.isRecharging and peer.rechargingEndsAt < firstTime: + firstTime = peer.rechargingEndsAt + + let rechargingEndedForSomePeer = firstTime < t + + network.reqCostSum = 0 + for peer in network.peers: + peer.rechargeReqCost firstTime + network.reqCostSum += peer.reqCostVal + + if rechargingEndedForSomePeer: + for peer in network.peers: + if peer.isRecharging: + updateRechargingParams peer, network + else: + network.lastUpdate = t + return + +proc endPendingRequest*(network: LesNetwork, peer: LesPeer, t: LesTime) = + if peer.reqCount > 0: + network.updateFlowControl t + network.trackRequests peer, -1 + network.updateFlowControl t + +proc enlistInFlowControl*(network: LesNetwork, + peer: LesPeer, + peerRechargingPower = 100) = + let t = now() + + assert peer.isServer or peer.isClient + # Each Peer must be potential communication partner for us. + # There will be useless peers on the network, but the logic + # should make sure to disconnect them earlier in `onPeerConnected`. + + if peer.isServer: + peer.localFlowState.init network.bufferLimit, network.minRechargingRate, t + peer.pendingReqs = initTable[int, ReqCostInt]() + + if peer.isClient: + peer.remoteFlowState.init network.bufferLimit, network.minRechargingRate, t + peer.lastRechargeTime = t + peer.rechargingEndsAt = t + peer.rechargingPower = peerRechargingPower + + network.updateFlowControl t + +proc delistFromFlowControl*(network: LesNetwork, peer: LesPeer) = + let t = now() + + # XXX: perhaps this is not safe with our reqCount logic. + # The original code may depend on the binarity of the `serving` flag. + network.endPendingRequest peer, t + network.updateFlowControl t + +proc initFlowControl*(network: LesNetwork, les: ProtocolInfo, + maxReqCount, maxReqCostSum, reqCostTarget: int, + db: AbstractChainDb = nil) = + network.rechargingRate = (rechargingScale * rechargingScale) / + (100 * rechargingScale / reqCostTarget - rechargingScale) + network.maxReqCount = maxReqCount + network.maxReqCostSum = maxReqCostSum + + if not network.loadMessageStats(les, db): + warn "Failed to load persisted LES message stats. " & + "Flow control will be re-initilized." + +proc canMakeRequest(peer: var LesPeer, maxCost: int): (LesTime, float64) = + peer.localFlowState.update now() + return peer.localFlowState.canMakeRequest(maxCost) + +template getRequestCost(peer: LesPeer, localOrRemote: untyped, + msgId, costQuantity: int): ReqCostInt = + template msgCostInfo: untyped = peer.`localOrRemote ReqCosts`[msgId] + + min(msgCostInfo.baseCost + msgCostInfo.reqCost * costQuantity, + peer.`localOrRemote FlowState`.bufLimit) + +proc trackOutgoingRequest*(network: LesNetwork, peer: LesPeer, + msgId, reqId, costQuantity: int) = + let maxCost = peer.getRequestCost(local, msgId, costQuantity) + + peer.localFlowState.bufValue -= maxCost + peer.pendingReqsCost += maxCost + peer.pendingReqs[reqId] = peer.pendingReqsCost + +proc trackIncomingResponse*(peer: LesPeer, reqId: int, bv: BufValueInt) = + let bv = min(bv, peer.localFlowState.bufLimit) + if not peer.pendingReqs.hasKey(reqId): + return + + let costsSumAtSending = peer.pendingReqs.pop(reqId) + let costsSumChange = peer.pendingReqsCost - costsSumAtSending + + peer.localFlowState.bufValue = if bv > costsSumChange: bv - costsSumChange + else: 0 + peer.localFlowState.lastUpdate = now() + +proc acceptRequest*(network: LesNetwork, peer: LesPeer, + msgId, costQuantity: int): Future[bool] {.async.} = + let t = now() + let reqCost = peer.getRequestCost(remote, msgId, costQuantity) + + peer.remoteFlowState.update t + network.updateFlowControl t + + while not network.canServeRequest: + await sleepAsync(10) + + if peer notin network.peers: + # The peer was disconnected or the network + # was shut down while we waited + return false + + network.trackRequests peer, +1 + network.updateFlowControl network.lastUpdate + + if reqCost > peer.remoteFlowState.bufValue: + error "LES peer sent request too early", + recharge = (reqCost - peer.remoteFlowState.bufValue) * rechargingScale / + peer.remoteFlowState.minRecharge + return false + + return true + +proc bufValueAfterRequest*(network: LesNetwork, peer: LesPeer, + msgId: int, quantity: int): BufValueInt = + let t = now() + let costs = peer.remoteReqCosts[msgId] + var reqCost = costs.baseCost + quantity * costs.reqCost + + peer.remoteFlowState.update t + peer.remoteFlowState.bufValue -= reqCost + + network.endPendingRequest peer, t + + let curReqCost = peer.reqCostVal + if curReqCost < peer.remoteFlowState.bufLimit: + let bv = peer.remoteFlowState.bufLimit - curReqCost + if bv > peer.remoteFlowState.bufValue: + peer.remoteFlowState.bufValue = bv + + network.messageStats[msgId].addSample(float64(quantity), + float64(curReqCost - peer.startReqCostVal)) + + return peer.remoteFlowState.bufValue + +when defined(testing): + import unittest, random, ../../rlpx + + proc isMax(s: FlowControlState): bool = + s.bufValue == s.bufLimit + + rlpxProtocol dummyLes(version = 1, shortName = "abc"): + proc a(p: Peer) + proc b(p: Peer) + proc c(p: Peer) + proc d(p: Peer) + proc e(p: Peer) + + template fequals(lhs, rhs: float64, epsilon = 0.0001): bool = + abs(lhs-rhs) < epsilon + + proc tests* = + randomize(3913631) + + suite "les flow control": + suite "running averages": + test "consistent costs": + var s: StatsRunningAverage + for i in 0..100: + s.addSample(5.0, 100.0) + + let (cost, base) = s.calc + + check: + fequals(cost, 100.0) + fequals(base, 0.0) + + test "randomized averages": + proc performTest(qBase, qRandom: int, cBase, cRandom: float64) = + var + s: StatsRunningAverage + expectedFinalCost = cBase + cRandom / 2 + error = expectedFinalCost + + for samples in [100, 1000, 10000]: + for i in 0..samples: + let q = float64(qBase + rand(10)) + s.addSample(q, q * (cBase + rand(cRandom))) + + let (newCost, newBase) = s.calc + # With more samples, our error should decrease, getting + # closer and closer to the average (unless we are already close enough) + let newError = abs(newCost - expectedFinalCost) + check newError < error + error = newError + + # After enough samples we should be very close the the final result + check error < (expectedFinalCost * 0.02) + + performTest(1, 10, 5.0, 100.0) + performTest(1, 4, 200.0, 1000.0) + + suite "buffer value calculations": + type TestReq = object + peer: LesPeer + msgId, quantity: int + accepted: bool + + setup: + var lesNetwork = new LesNetwork + lesNetwork.peers = initSet[LesPeer]() + lesNetwork.initFlowControl(dummyLes.protocolInfo, + reqCostTarget = 300, + maxReqCount = 5, + maxReqCostSum = 1000) + + for i in 0 ..< lesNetwork.messageStats.len: + lesNetwork.messageStats[i].addSample(1.0, float(i) * 100.0) + + var client = new LesPeer + client.isClient = true + + var server = new LesPeer + server.isServer = true + + var clientServer = new LesPeer + clientServer.isClient = true + clientServer.isServer = true + + var client2 = new LesPeer + client2.isClient = true + + var client3 = new LesPeer + client3.isClient = true + + var bv: BufValueInt + + template enlist(peer: LesPeer) {.dirty.} = + let reqCosts = currentRequestsCosts(lesNetwork, dummyLes.protocolInfo) + peer.remoteReqCosts = reqCosts + peer.localReqCosts = reqCosts + lesNetwork.peers.incl peer + lesNetwork.enlistInFlowControl peer + + template startReq(p: LesPeer, msg, q: int): TestReq = + var req: TestReq + req.peer = p + req.msgId = msg + req.quantity = q + req.accepted = waitFor lesNetwork.acceptRequest(p, msg, q) + req + + template endReq(req: TestReq): BufValueInt = + bufValueAfterRequest(lesNetwork, req.peer, req.msgId, req.quantity) + + test "single peer recharging": + lesNetwork.bufferLimit = 1000 + lesNetwork.minRechargingRate = 100 + + enlist client + + check: + client.remoteFlowState.isMax + client.rechargingPower > 0 + + advanceTime 100 + + let r1 = client.startReq(0, 100) + check r1.accepted + check client.isRecharging == false + + advanceTime 50 + + let r2 = client.startReq(1, 1) + check r2.accepted + check client.isRecharging == false + + advanceTime 25 + bv = endReq r2 + check client.isRecharging == false + + advanceTime 130 + bv = endReq r1 + check client.isRecharging == true + + advanceTime 300 + lesNetwork.updateFlowControl now() + + check: + client.isRecharging == false + client.remoteFlowState.isMax + diff --git a/eth_p2p/rlpx_protocols/les/private/les_types.nim b/eth_p2p/rlpx_protocols/les/private/les_types.nim new file mode 100644 index 0000000..0b40717 --- /dev/null +++ b/eth_p2p/rlpx_protocols/les/private/les_types.nim @@ -0,0 +1,113 @@ +import + hashes, tables, sets, + package_visible_types, + eth_common/eth_types + +packageTypes: + type + AnnounceType* = enum + None, + Simple, + Signed, + Unspecified + + ReqCostInfo = object + msgId: int + baseCost, reqCost: ReqCostInt + + FlowControlState = object + bufValue, bufLimit: int + minRecharge: int + lastUpdate: LesTime + + StatsRunningAverage = object + sumX, sumY, sumXX, sumXY: float64 + count: int + + LesPeer* = ref object + isServer*: bool + isClient*: bool + announceType*: AnnounceType + + bestDifficulty*: DifficultyInt + bestBlockHash*: KeccakHash + bestBlockNumber*: BlockNumber + + hasChainSince: HashOrNum + hasStateSince: HashOrNum + relaysTransactions: bool + + # The variables below are used to implement the flow control + # mechanisms of LES from our point of view as a server. + # They describe how much load has been generated by this + # particular peer. + reqCount: int # How many outstanding requests are there? + # + rechargingPower: int # Do we give this peer any extra priority + # (implemented as a faster recharning rate) + # 100 is the default. You can go higher and lower. + # + isRecharging: bool # This is true while the peer is not making + # any requests + # + reqCostGradient: int # Measures the speed of recharging or accumulating + # "requests cost" at any given moment. + # + reqCostVal: int # The accumulated "requests cost" + # + rechargingEndsAt: int # When will recharging end? + # (the buffer of the Peer will be fully restored) + # + lastRechargeTime: LesTime # When did we last update the recharging parameters + # + startReqCostVal: int # TODO + + remoteFlowState: FlowControlState + remoteReqCosts: seq[ReqCostInfo] + + # The next variables are used to limit ourselves as a client in order to + # not violate the control-flow requirements of the remote LES server. + + pendingReqs: Table[int, ReqCostInt] + pendingReqsCost: int + + localFlowState: FlowControlState + localReqCosts: seq[ReqCostInfo] + + LesNetwork* = ref object + peers: HashSet[LesPeer] + messageStats: seq[StatsRunningAverage] + ourAnnounceType*: AnnounceType + + # The fields below are relevant when serving data. + bufferLimit: int + minRechargingRate: int + + reqCostSum, maxReqCostSum: ReqCostInt + reqCount, maxReqCount: int + sumWeigth: int + + rechargingRate: int + totalRechargedUnits: int + totalRechargingPower: int + + lastUpdate: LesTime + + KeyValuePair = object + key: string + value: Blob + + HandshakeError = object of Exception + + LesTime = int # this is in milliseconds + BufValueInt = int + ReqCostInt = int + +template hash*(peer: LesPeer): Hash = hash(cast[pointer](peer)) + +template areWeServingData*(network: LesNetwork): bool = + network.maxReqCount != 0 + +template areWeRequestingData*(network: LesNetwork): bool = + network.ourAnnounceType != AnnounceType.Unspecified + diff --git a/eth_p2p/rlpx_protocols/les_protocol.nim b/eth_p2p/rlpx_protocols/les_protocol.nim index f0d34c9..e5a13e2 100644 --- a/eth_p2p/rlpx_protocols/les_protocol.nim +++ b/eth_p2p/rlpx_protocols/les_protocol.nim @@ -9,64 +9,26 @@ # import - times, - chronicles, asyncdispatch2, rlp, eth_common/eth_types, - ../../eth_p2p + times, tables, options, sets, hashes, strutils, macros, + chronicles, asyncdispatch2, nimcrypto/[keccak, hash], + rlp, eth_common/eth_types, eth_keys, + ../rlpx, ../kademlia, ../private/types, ../blockchain_utils, + les/private/les_types, les/flow_control -type - ProofRequest* = object - blockHash*: KeccakHash - accountKey*: Blob - key*: Blob - fromLevel*: uint - - HeaderProofRequest* = object - chtNumber*: uint - blockNumber*: uint - fromLevel*: uint - - ContractCodeRequest* = object - blockHash*: KeccakHash - key*: EthAddress - - HelperTrieProofRequest* = object - subType*: uint - sectionIdx*: uint - key*: Blob - fromLevel*: uint - auxReq*: uint - - TransactionStatus* = enum - Unknown, - Queued, - Pending, - Included, - Error - - TransactionStatusMsg* = object - status*: TransactionStatus - data*: Blob - - PeerState = object - buffer: int - lastRequestTime: float - reportedTotalDifficulty: DifficultyInt - - KeyValuePair = object - key: string - value: Blob +les_types.forwardPublicTypes const + lesVersion = 2'u maxHeadersFetch = 192 maxBodiesFetch = 32 maxReceiptsFetch = 128 maxCodeFetch = 64 maxProofsFetch = 64 maxHeaderProofsFetch = 64 + maxTransactionsFetch = 64 -# Handshake properties: -# https://github.com/zsfelfoldi/go-ethereum/wiki/Light-Ethereum-Subprotocol-(LES) -const + # Handshake properties: + # https://github.com/zsfelfoldi/go-ethereum/wiki/Light-Ethereum-Subprotocol-(LES) keyProtocolVersion = "protocolVersion" ## P: is 1 for the LPV1 protocol version. @@ -110,98 +72,393 @@ const ## see Client Side Flow Control: ## https://github.com/zsfelfoldi/go-ethereum/wiki/Client-Side-Flow-Control-model-for-the-LES-protocol -const - rechargeRate = 0.3 + keyAnnounceType = "announceType" + keyAnnounceSignature = "sign" -proc getPeerWithNewestChain(pool: PeerPool): Peer = - discard +proc initProtocolState(network: LesNetwork, node: EthereumNode) = + network.peers = initSet[LesPeer]() -rlpxProtocol les, 2: +proc addPeer(network: LesNetwork, peer: LesPeer) = + network.enlistInFlowControl peer + network.peers.incl peer - type State = PeerState +proc removePeer(network: LesNetwork, peer: LesPeer) = + network.delistFromFlowControl peer + network.peers.excl peer + +template costQuantity(quantityExpr: int, max: int) {.pragma.} + +proc getCostQuantity(fn: NimNode): tuple[quantityExpr, maxQuantity: NimNode] = + # XXX: `getCustomPragmaVal` doesn't work yet on regular nnkProcDef nodes + # (TODO: file as an issue) + let p = fn.pragma + assert p.kind == nnkPragma and p.len > 0 and $p[0][0] == "costQuantity" + + result.quantityExpr = p[0][1] + result.maxQuantity= p[0][2] + + if result.maxQuantity.kind == nnkExprEqExpr: + result.maxQuantity = result.maxQuantity[1] + +macro outgoingRequestDecorator(n: untyped): untyped = + result = n + let (costQuantity, maxQuantity) = n.getCostQuantity + + result.body.add quote do: + trackOutgoingRequest(msgRecipient.networkState(les), + msgRecipient.state(les), + perProtocolMsgId, reqId, `costQuantity`) + # echo result.repr + +macro incomingResponseDecorator(n: untyped): untyped = + result = n + + let trackingCall = quote do: + trackIncomingResponse(msgSender.state(les), reqId, msg.bufValue) + + result.body.insert(n.body.len - 1, trackingCall) + # echo result.repr + +macro incomingRequestDecorator(n: untyped): untyped = + result = n + let (costQuantity, maxQuantity) = n.getCostQuantity + + template acceptStep(quantityExpr, maxQuantity) {.dirty.} = + let requestCostQuantity = quantityExpr + if requestCostQuantity > maxQuantity: + await peer.disconnect(BreachOfProtocol) + return + + let lesPeer = peer.state + let lesNetwork = peer.networkState + + if not await acceptRequest(lesNetwork, lesPeer, + perProtocolMsgId, + requestCostQuantity): return + + result.body.insert(1, getAst(acceptStep(costQuantity, maxQuantity))) + # echo result.repr + +template updateBV: BufValueInt = + bufValueAfterRequest(lesNetwork, lesPeer, + perProtocolMsgId, requestCostQuantity) + +func getValue(values: openarray[KeyValuePair], + key: string, T: typedesc): Option[T] = + for v in values: + if v.key == key: + return some(rlp.decode(v.value, T)) + +func getRequiredValue(values: openarray[KeyValuePair], + key: string, T: typedesc): T = + for v in values: + if v.key == key: + return rlp.decode(v.value, T) + + raise newException(HandshakeError, + "Required handshake field " & key & " missing") + +rlpxProtocol les(version = lesVersion, + peerState = LesPeer, + networkState = LesNetwork, + outgoingRequestDecorator = outgoingRequestDecorator, + incomingRequestDecorator = incomingRequestDecorator, + incomingResponseThunkDecorator = incomingResponseDecorator): ## Handshake ## - proc status(p: Peer, values: openarray[KeyValuePair]) = - discard + proc status(p: Peer, values: openarray[KeyValuePair]) + + onPeerConnected do (peer: Peer): + let + network = peer.network + chain = network.chain + bestBlock = chain.getBestBlockHeader + lesPeer = peer.state + lesNetwork = peer.networkState + + template `=>`(k, v: untyped): untyped = + KeyValuePair.init(key = k, value = rlp.encode(v)) + + var lesProperties = @[ + keyProtocolVersion => lesVersion, + keyNetworkId => network.networkId, + keyHeadTotalDifficulty => bestBlock.difficulty, + keyHeadHash => bestBlock.blockHash, + keyHeadNumber => bestBlock.blockNumber, + keyGenesisHash => chain.genesisHash + ] + + lesPeer.remoteReqCosts = currentRequestsCosts(lesNetwork, les.protocolInfo) + + if lesNetwork.areWeServingData: + lesProperties.add [ + # keyServeHeaders => nil, + keyServeChainSince => 0, + keyServeStateSince => 0, + # keyRelaysTransactions => nil, + keyFlowControlBL => lesNetwork.bufferLimit, + keyFlowControlMRR => lesNetwork.minRechargingRate, + keyFlowControlMRC => lesPeer.remoteReqCosts + ] + + if lesNetwork.areWeRequestingData: + lesProperties.add(keyAnnounceType => lesNetwork.ourAnnounceType) + + let + s = await peer.nextMsg(les.status) + peerNetworkId = s.values.getRequiredValue(keyNetworkId, uint) + peerGenesisHash = s.values.getRequiredValue(keyGenesisHash, KeccakHash) + peerLesVersion = s.values.getRequiredValue(keyProtocolVersion, uint) + + template requireCompatibility(peerVar, localVar, varName: untyped) = + if localVar != peerVar: + raise newException(HandshakeError, + "Incompatibility detected! $1 mismatch ($2 != $3)" % + [varName, $localVar, $peerVar]) + + requireCompatibility(peerLesVersion, lesVersion, "les version") + requireCompatibility(peerNetworkId, network.networkId, "network id") + requireCompatibility(peerGenesisHash, chain.genesisHash, "genesis hash") + + template `:=`(lhs, key) = + lhs = s.values.getRequiredValue(key, type(lhs)) + + lesPeer.bestBlockHash := keyHeadHash + lesPeer.bestBlockNumber := keyHeadNumber + lesPeer.bestDifficulty := keyHeadTotalDifficulty + + let peerAnnounceType = s.values.getValue(keyAnnounceType, AnnounceType) + if peerAnnounceType.isSome: + lesPeer.isClient = true + lesPeer.announceType = peerAnnounceType.get + else: + lesPeer.announceType = AnnounceType.Simple + lesPeer.hasChainSince := keyServeChainSince + lesPeer.hasStateSince := keyServeStateSince + lesPeer.relaysTransactions := keyRelaysTransactions + lesPeer.localFlowState.bufLimit := keyFlowControlBL + lesPeer.localFlowState.minRecharge := keyFlowControlMRR + lesPeer.localReqCosts := keyFlowControlMRC + + lesNetwork.addPeer lesPeer + + onPeerDisconnected do (peer: Peer, reason: DisconnectionReason) {.gcsafe.}: + peer.networkState.removePeer peer.state ## Header synchronisation ## - proc announce(p: Peer, - headHash: KeccakHash, - headNumber: BlockNumber, - headTotalDifficulty: DifficultyInt, - reorgDepth: BlockNumber, - values: openarray[KeyValuePair], - announceType: uint) = - discard + proc announce( + peer: Peer, + headHash: KeccakHash, + headNumber: BlockNumber, + headTotalDifficulty: DifficultyInt, + reorgDepth: BlockNumber, + values: openarray[KeyValuePair], + announceType: AnnounceType) = + + if peer.state.announceType == AnnounceType.None: + error "unexpected announce message", peer + return + + if announceType == AnnounceType.Signed: + let signature = values.getValue(keyAnnounceSignature, Blob) + if signature.isNone: + error "missing announce signature" + return + let sigHash = keccak256.digest rlp.encodeList(headHash, + headNumber, + headTotalDifficulty) + let signerKey = recoverKeyFromSignature(signature.get.initSignature, + sigHash) + if signerKey.toNodeId != peer.remote.id: + error "invalid announce signature" + # TODO: should we disconnect this peer? + return + + # TODO: handle new block requestResponse: - proc getBlockHeaders(p: Peer, BV: uint, req: BlocksRequest) = - discard + proc getBlockHeaders( + peer: Peer, + req: BlocksRequest) {. + costQuantity(req.maxResults.int, max = maxHeadersFetch).} = - proc blockHeaders(p: Peer, BV: uint, blocks: openarray[BlockHeader]) = - discard + let headers = peer.network.chain.getBlockHeaders(req) + await peer.blockHeaders(reqId, updateBV(), headers) + + proc blockHeaders( + peer: Peer, + bufValue: BufValueInt, + blocks: openarray[BlockHeader]) ## On-damand data retrieval ## requestResponse: - proc getBlockBodies(p: Peer, blocks: openarray[KeccakHash]) = - discard + proc getBlockBodies( + peer: Peer, + blocks: openarray[KeccakHash]) {. + costQuantity(blocks.len, max = maxBodiesFetch).} = - proc blockBodies(p: Peer, BV: uint, bodies: openarray[BlockBody]) = - discard + let blocks = peer.network.chain.getBlockBodies(blocks) + await peer.blockBodies(reqId, updateBV(), blocks) + + proc blockBodies( + peer: Peer, + bufValue: BufValueInt, + bodies: openarray[BlockBody]) requestResponse: - proc getReceipts(p: Peer, hashes: openarray[KeccakHash]) = - discard + proc getReceipts( + peer: Peer, + hashes: openarray[KeccakHash]) + {.costQuantity(hashes.len, max = maxReceiptsFetch).} = - proc receipts(p: Peer, BV: uint, receipts: openarray[Receipt]) = - discard + let receipts = peer.network.chain.getReceipts(hashes) + await peer.receipts(reqId, updateBV(), receipts) + + proc receipts( + peer: Peer, + bufValue: BufValueInt, + receipts: openarray[Receipt]) requestResponse: - proc getProofs(p: Peer, proofs: openarray[ProofRequest]) = - discard + proc getProofs( + peer: Peer, + proofs: openarray[ProofRequest]) {. + costQuantity(proofs.len, max = maxProofsFetch).} = - proc proofs(p: Peer, BV: uint, proofs: openarray[Blob]) = - discard + let proofs = peer.network.chain.getProofs(proofs) + await peer.proofs(reqId, updateBV(), proofs) + + proc proofs( + peer: Peer, + bufValue: BufValueInt, + proofs: openarray[Blob]) requestResponse: - proc getContractCodes(p: Peer, requests: seq[ContractCodeRequest]) = - discard + proc getContractCodes( + peer: Peer, + reqs: seq[ContractCodeRequest]) {. + costQuantity(reqs.len, max = maxCodeFetch).} = - proc contractCodes(p: Peer, BV: uint, results: seq[Blob]) = - discard + let results = peer.network.chain.getContractCodes(reqs) + await peer.contractCodes(reqId, updateBV(), results) + + proc contractCodes( + peer: Peer, + bufValue: BufValueInt, + results: seq[Blob]) nextID 15 requestResponse: - proc getHeaderProofs(p: Peer, requests: openarray[ProofRequest]) = - discard + proc getHeaderProofs( + peer: Peer, + reqs: openarray[ProofRequest]) {. + costQuantity(reqs.len, max = maxHeaderProofsFetch).} = - proc headerProof(p: Peer, BV: uint, proofs: openarray[Blob]) = - discard + let proofs = peer.network.chain.getHeaderProofs(reqs) + await peer.headerProofs(reqId, updateBV(), proofs) + + proc headerProofs( + peer: Peer, + bufValue: BufValueInt, + proofs: openarray[Blob]) requestResponse: - proc getHelperTrieProofs(p: Peer, requests: openarray[HelperTrieProofRequest]) = - discard + proc getHelperTrieProofs( + peer: Peer, + reqs: openarray[HelperTrieProofRequest]) {. + costQuantity(reqs.len, max = maxProofsFetch).} = - proc helperTrieProof(p: Peer, BV: uint, nodes: seq[Blob], auxData: seq[Blob]) = - discard + var nodes, auxData: seq[Blob] + peer.network.chain.getHelperTrieProofs(reqs, nodes, auxData) + await peer.helperTrieProofs(reqId, updateBV(), nodes, auxData) + + proc helperTrieProofs( + peer: Peer, + bufValue: BufValueInt, + nodes: seq[Blob], + auxData: seq[Blob]) ## Transaction relaying and status retrieval ## requestResponse: - proc sendTxV2(p: Peer, transactions: openarray[Transaction]) = - discard + proc sendTxV2( + peer: Peer, + transactions: openarray[Transaction]) {. + costQuantity(transactions.len, max = maxTransactionsFetch).} = - proc getTxStatus(p: Peer, transactions: openarray[Transaction]) = - discard + let chain = peer.network.chain - proc txStatus(p: Peer, BV: uint, transactions: openarray[TransactionStatusMsg]) = - discard + var results: seq[TransactionStatusMsg] + for t in transactions: + let hash = t.rlpHash # TODO: this is not optimal, we can compute + # the hash from the request bytes. + # The RLP module can offer a helper Hashed[T] + # to make this easy. + var s = chain.getTransactionStatus(hash) + if s.status == TransactionStatus.Unknown: + chain.addTransactions([t]) + s = chain.getTransactionStatus(hash) + + results.add s + + await peer.txStatus(reqId, updateBV(), results) + + proc getTxStatus( + peer: Peer, + transactions: openarray[Transaction]) {. + costQuantity(transactions.len, max = maxTransactionsFetch).} = + + let chain = peer.network.chain + + var results: seq[TransactionStatusMsg] + for t in transactions: + results.add chain.getTransactionStatus(t.rlpHash) + await peer.txStatus(reqId, updateBV(), results) + + proc txStatus( + peer: Peer, + bufValue: BufValueInt, + transactions: openarray[TransactionStatusMsg]) + +proc configureLes*(node: EthereumNode, + # Client options: + announceType = AnnounceType.Simple, + # Server options. + # The zero default values indicate that the + # LES server will be deactivated. + maxReqCount = 0, + maxReqCostSum = 0, + reqCostTarget = 0) = + + doAssert announceType != AnnounceType.Unspecified or maxReqCount > 0 + + var lesNetwork = node.protocolState(les) + lesNetwork.ourAnnounceType = announceType + initFlowControl(lesNetwork, les.protocolInfo, + maxReqCount, maxReqCostSum, reqCostTarget, + node.chain) + +proc configureLesServer*(node: EthereumNode, + # Client options: + announceType = AnnounceType.Unspecified, + # Server options. + # The zero default values indicate that the + # LES server will be deactivated. + maxReqCount = 0, + maxReqCostSum = 0, + reqCostTarget = 0) = + ## This is similar to `configureLes`, but with default parameter + ## values appropriate for a server. + node.configureLes(announceType, maxReqCount, maxReqCostSum, reqCostTarget) + +proc persistLesMessageStats*(node: EthereumNode) = + persistMessageStats(node.chain, node.protocolState(les)) diff --git a/tests/all_tests.nim b/tests/all_tests.nim new file mode 100644 index 0000000..7fa93e0 --- /dev/null +++ b/tests/all_tests.nim @@ -0,0 +1,4 @@ +import + testecies, testauth, testcrypt, + les/test_flow_control + diff --git a/tests/les/test_flow_control.nim b/tests/les/test_flow_control.nim new file mode 100644 index 0000000..a40d104 --- /dev/null +++ b/tests/les/test_flow_control.nim @@ -0,0 +1,5 @@ +import + eth_p2p/rlpx_protocols/les/flow_control + +flow_control.tests() + diff --git a/tests/nim.cfg b/tests/nim.cfg new file mode 100644 index 0000000..71c4a56 --- /dev/null +++ b/tests/nim.cfg @@ -0,0 +1 @@ +d:testing diff --git a/tests/tserver.nim b/tests/tserver.nim index 3d97194..b55cfe2 100644 --- a/tests/tserver.nim +++ b/tests/tserver.nim @@ -7,31 +7,127 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) -import sequtils -import eth_keys, asyncdispatch2 -import eth_p2p +import + sequtils, strformat, options, unittest, + chronicles, asyncdispatch2, rlp, eth_keys, + eth_p2p, eth_p2p/mock_peers -const clientId = "nim-eth-p2p/0.0.1" +const + clientId = "nim-eth-p2p/0.0.1" -rlpxProtocol dmy, 1: # Rlpx would be useless with no subprotocols. So we define a dummy proto - proc foo(peer: Peer) +type + AbcPeer = ref object + peerName: string + lastResponse: string + + XyzPeer = ref object + messages: int + + AbcNetwork = ref object + peers: seq[string] + +rlpxProtocol abc(version = 1, + peerState = AbcPeer, + networkState = AbcNetwork, + timeout = 100): + + onPeerConnected do (peer: Peer): + await peer.hi "Bob" + let response = await peer.nextMsg(abc.hi) + peer.networkState.peers.add response.name + + onPeerDisconnected do (peer: Peer, reason: DisconnectionReason): + echo "peer disconnected", peer + + requestResponse: + proc abcReq(p: Peer, n: int) = + echo "got req ", n + await p.abcRes(reqId, &"response to #{n}") + + proc abcRes(p: Peer, data: string) = + echo "got response ", data + + proc hi(p: Peer, name: string) = + echo "got hi from ", name + p.state.peerName = name + let query = 123 + echo "sending req #", query + var r = await p.abcReq(query) + if r.isSome: + p.state.lastResponse = r.get.data + else: + p.state.lastResponse = "timeout" + +rlpxProtocol xyz(version = 1, + peerState = XyzPeer, + useRequestIds = false, + timeout = 100): + + proc foo(p: Peer, s: string, a, z: int) = + p.state.messages += 1 + if p.supports(abc): + echo p.state(abc).peerName + + proc bar(p: Peer, i: int, s: string) + + requestResponse: + proc xyzReq(p: Peer, n: int, timeout = 3000) = + echo "got req ", n + + proc xyzRes(p: Peer, data: string) = + echo "got response ", data + +proc defaultTestingHandshake(_: type abc): abc.hi = + result.name = "John Doe" proc localAddress(port: int): Address = let port = Port(port) result = Address(udpPort: port, tcpPort: port, ip: parseIpAddress("127.0.0.1")) -proc test() {.async.} = - let node1Keys = newKeyPair() - let node1Address = localAddress(30303) - var node1 = newEthereumNode(node1Keys, node1Address, 1, nil) - node1.startListening() +template asyncTest(name, body: untyped) = + test name: + proc scenario {.async.} = body + waitFor scenario() - let node2Keys = newKeyPair() - var node2 = newEthereumNode(node2Keys, localAddress(30304), 1, nil) +asyncTest "network with 3 peers using custom protocols": + let localKeys = newKeyPair() + let localAddress = localAddress(30303) + var localNode = newEthereumNode(localKeys, localAddress, 1, nil) + localNode.initProtocolStates() + localNode.startListening() - let node1AsRemote = newNode(initENode(node1Keys.pubKey, node1Address)) - let peer = await node2.rlpxConnect(node1AsRemote) + var mock1 = newMockPeer do (m: MockConf): + m.addHandshake abc.hi(name: "Alice") - doAssert(not peer.isNil) + m.expect(abc.abcReq) do (peer: Peer, data: Rlp): + let reqId = data.readReqId() + await peer.abcRes(reqId, "mock response") + await sleepAsync(100) + let r = await peer.abcReq(1) + assert r.get.data == "response to #1" + + m.expect(abc.abcRes) + + var mock2 = newMockPeer do (m: MockConf): + m.addCapability xyz + m.addCapability abc + + m.expect(abc.abcReq) # we'll let this one time out + + m.expect(xyz.xyzReq) do (peer: Peer): + echo "got xyz req" + await peer.xyzRes("mock peer data") + + discard await mock1.rlpxConnect(localNode) + let mock2Connection = await localNode.rlpxConnect(mock2) + + let r = await mock2Connection.xyzReq(10) + check r.get.data == "mock peer data" + + let abcNetState = localNode.protocolState(abc) + + check: + abcNetState.peers.len == 2 + "Alice" in abcNetState.peers + "John Doe" in abcNetState.peers -waitFor test()