diff --git a/eth/p2p/blockchain_sync.nim b/eth/p2p/blockchain_sync.nim index 803e3d1..be14e21 100644 --- a/eth/p2p/blockchain_sync.nim +++ b/eth/p2p/blockchain_sync.nim @@ -331,13 +331,12 @@ proc onPeerDisconnected(ctx: SyncContext, p: Peer) = proc startSync(ctx: SyncContext) = var po: PeerObserver po.onPeerConnected = proc(p: Peer) {.gcsafe.} = - if p.supports(eth): - ctx.onPeerConnected(p) + ctx.onPeerConnected(p) po.onPeerDisconnected = proc(p: Peer) {.gcsafe.} = - if p.supports(eth): - ctx.onPeerDisconnected(p) + ctx.onPeerDisconnected(p) + po.setProtocol eth ctx.peerPool.addObserver(ctx, po) proc findBestPeer(node: EthereumNode): (Peer, DifficultyInt) = diff --git a/eth/p2p/peer_pool.nim b/eth/p2p/peer_pool.nim index 4bb479e..328fdab 100644 --- a/eth/p2p/peer_pool.nim +++ b/eth/p2p/peer_pool.nim @@ -35,7 +35,8 @@ proc addObserver(p: PeerPool, observerId: int, observer: PeerObserver) = p.observers[observerId] = observer if not observer.onPeerConnected.isNil: for peer in p.connectedNodes.values: - observer.onPeerConnected(peer) + if observer.protocol.isNil or peer.supports(observer.protocol): + observer.onPeerConnected(peer) proc delObserver(p: PeerPool, observerId: int) = p.observers.del(observerId) @@ -46,6 +47,9 @@ proc addObserver*(p: PeerPool, observerId: ref, observer: PeerObserver) {.inline proc delObserver*(p: PeerPool, observerId: ref) {.inline.} = p.delObserver(cast[int](observerId)) +template setProtocol*(observer: PeerObserver, Protocol: type) = + observer.protocol = Protocol.protocolInfo + proc stopAllPeers(p: PeerPool) {.async.} = debug "Stopping all peers ..." # TODO: ... @@ -108,7 +112,8 @@ proc addPeer*(pool: PeerPool, peer: Peer): bool = pool.connectedNodes[peer.remote] = peer for o in pool.observers.values: if not o.onPeerConnected.isNil: - o.onPeerConnected(peer) + if o.protocol.isNil or peer.supports(o.protocol): + o.onPeerConnected(peer) return true else: return false diff --git a/eth/p2p/private/p2p_types.nim b/eth/p2p/private/p2p_types.nim index cb0e125..cc06fc2 100644 --- a/eth/p2p/private/p2p_types.nim +++ b/eth/p2p/private/p2p_types.nim @@ -59,6 +59,7 @@ type PeerObserver* = object onPeerConnected*: proc(p: Peer) {.gcsafe.} onPeerDisconnected*: proc(p: Peer) {.gcsafe.} + protocol*: ProtocolInfo Capability* = object name*: string diff --git a/eth/p2p/rlpx.nim b/eth/p2p/rlpx.nim index feebf75..43e2d2e 100644 --- a/eth/p2p/rlpx.nim +++ b/eth/p2p/rlpx.nim @@ -227,9 +227,6 @@ proc registerProtocol(protocol: ProtocolInfo) = # Message composition and encryption # -template protocolOffset(peer: Peer, Protocol: type): int = - peer.dispatcher.protocolOffsets[Protocol.protocolInfo.index] - proc perPeerMsgIdImpl(peer: Peer, proto: ProtocolInfo, msgId: int): int {.inline.} = result = msgId if not peer.dispatcher.isNil: @@ -239,9 +236,12 @@ template getPeer(peer: Peer): auto = peer template getPeer(response: Response): auto = Peer(response) template getPeer(response: ResponseWithId): auto = response.peer +proc supports*(peer: Peer, proto: ProtocolInfo): bool {.inline.} = + peer.dispatcher.protocolOffsets[proto.index] != -1 + proc supports*(peer: Peer, Protocol: type): bool {.inline.} = ## Checks whether a Peer supports a particular protocol - peer.protocolOffset(Protocol) != -1 + peer.supports(Protocol.protocolInfo) template perPeerMsgId(peer: Peer, MsgType: type): int = perPeerMsgIdImpl(peer, MsgType.msgProtocol.protocolInfo, MsgType.msgId) @@ -1124,7 +1124,8 @@ proc removePeer(network: EthereumNode, peer: Peer) = for observer in network.peerPool.observers.values: if not observer.onPeerDisconnected.isNil: - observer.onPeerDisconnected(peer) + if observer.protocol.isNil or peer.supports(observer.protocol): + observer.onPeerDisconnected(peer) proc callDisconnectHandlers(peer: Peer, reason: DisconnectionReason): Future[void] = var futures = newSeqOfCap[Future[void]](allProtocols.len)