diff --git a/eth/p2p.nim b/eth/p2p.nim index fe990e8..0e30a96 100644 --- a/eth/p2p.nim +++ b/eth/p2p.nim @@ -8,7 +8,7 @@ {.push raises: [Defect].} import - std/[tables, algorithm, random], + std/[tables, algorithm, random, typetraits, strutils], chronos, chronos/timer, chronicles, ./keys, ./p2p/private/p2p_types, ./p2p/[kademlia, discovery, enode, peer_pool, rlpx] @@ -16,19 +16,59 @@ import export p2p_types, rlpx, enode, kademlia -proc addCapability*(node: var EthereumNode, p: ProtocolInfo) = +proc addCapability*(node: var EthereumNode, + p: ProtocolInfo, + networkState: RootRef = nil) = doAssert node.connectionState == ConnectionState.None let pos = lowerBound(node.protocols, p, rlpx.cmp) node.protocols.insert(p, pos) node.capabilities.insert(p.asCapability, pos) - if p.networkStateInitializer != nil: + if p.networkStateInitializer != nil and networkState.isNil: node.protocolStates[p.index] = p.networkStateInitializer(node) + if networkState.isNil.not: + node.protocolStates[p.index] = networkState + template addCapability*(node: var EthereumNode, Protocol: type) = addCapability(node, Protocol.protocolInfo) +template addCapability*(node: var EthereumNode, + Protocol: type, + networkState: untyped) = + mixin NetworkState + type + ParamType = type(networkState) + + when ParamType isnot Protocol.NetworkState: + const errMsg = "`$1` is not compatible with `$2`" % [ + name(ParamType), name(Protocol.NetworkState)] + {. error: errMsg .} + + addCapability(node, Protocol.protocolInfo, + cast[RootRef](networkState)) + +proc replaceNetworkState*(node: var EthereumNode, + p: ProtocolInfo, + networkState: RootRef) = + node.protocolStates[p.index] = networkState + +template replaceNetworkState*(node: var EthereumNode, + Protocol: type, + networkState: untyped) = + mixin NetworkState + type + ParamType = type(networkState) + + when ParamType isnot Protocol.NetworkState: + const errMsg = "`$1` is not compatible with `$2`" % [ + name(ParamType), name(Protocol.NetworkState)] + {. error: errMsg .} + + replaceNetworkState(node, Protocol.protocolInfo, + cast[RootRef](networkState)) + proc newEthereumNode*( keys: KeyPair, address: Address, diff --git a/tests/p2p/test_protocol_handlers.nim b/tests/p2p/test_protocol_handlers.nim index 6e609f3..fd35a8d 100644 --- a/tests/p2p/test_protocol_handlers.nim +++ b/tests/p2p/test_protocol_handlers.nim @@ -86,3 +86,9 @@ suite "Testing protocol handlers": peer.isNil == true # To check if the disconnection handler did not run node1.protocolState(hah).count == 0 + + test "Override network state": + let rng = newRng() + var node = setupTestNode(rng, hah) + node.addCapability(hah, network()) + node.replaceNetworkState(hah, network())