mirror of https://github.com/status-im/nim-eth.git
refactor p2pProtocol internals
Nim devel brach(1.7.1) introduce gc=orc as default mode. Because the p2p protocol using unsafe pointer operations for it's ProtocolInfo and using global variables scattered around, the orc mistakenly(or maybe correctly) crash the protocol.
This commit is contained in:
parent
d238693571
commit
e1bdf1741a
|
@ -106,7 +106,7 @@ proc newEthereumNode*(
|
|||
result.protocolVersion = if useCompression: devp2pSnappyVersion
|
||||
else: devp2pVersion
|
||||
|
||||
result.protocolStates.newSeq allProtocols.len
|
||||
result.protocolStates.newSeq protocolCount()
|
||||
|
||||
result.peerPool = newPeerPool(
|
||||
result, networkId, keys, nil, clientId, minPeers = minPeers)
|
||||
|
@ -114,8 +114,8 @@ proc newEthereumNode*(
|
|||
result.peerPool.discovery = result.discovery
|
||||
|
||||
if addAllCapabilities:
|
||||
for p in allProtocols:
|
||||
result.addCapability(p)
|
||||
for cap in protocols():
|
||||
result.addCapability(cap)
|
||||
|
||||
proc processIncoming(server: StreamServer,
|
||||
remote: StreamTransport): Future[void] {.async, gcsafe.} =
|
||||
|
|
|
@ -1,9 +1,33 @@
|
|||
var
|
||||
gProtocols: seq[ProtocolInfo]
|
||||
let protocolManager = ProtocolManager()
|
||||
|
||||
# The variables above are immutable RTTI information. We need to tell
|
||||
# Nim to not consider them GcSafe violations:
|
||||
template allProtocols*: auto = {.gcsafe.}: gProtocols
|
||||
|
||||
proc registerProtocol*(proto: ProtocolInfo) {.gcsafe.} =
|
||||
{.gcsafe.}:
|
||||
proto.index = protocolManager.protocols.len
|
||||
if proto.name == "p2p":
|
||||
doAssert(proto.index == 0)
|
||||
protocolManager.protocols.add proto
|
||||
|
||||
proc protocolCount*(): int {.gcsafe.} =
|
||||
{.gcsafe.}:
|
||||
protocolManager.protocols.len
|
||||
|
||||
proc getProtocol*(index: int): ProtocolInfo {.gcsafe.} =
|
||||
{.gcsafe.}:
|
||||
protocolManager.protocols[index]
|
||||
|
||||
iterator protocols*(): ProtocolInfo {.gcsafe.} =
|
||||
{.gcsafe.}:
|
||||
for x in protocolManager.protocols:
|
||||
yield x
|
||||
|
||||
template getProtocol*(Protocol: type): ProtocolInfo =
|
||||
getProtocol(Protocol.index)
|
||||
|
||||
template devp2pInfo*(): ProtocolInfo =
|
||||
getProtocol(0)
|
||||
|
||||
proc getState*(peer: Peer, proto: ProtocolInfo): RootRef =
|
||||
peer.protocolStates[proto.index]
|
||||
|
@ -35,9 +59,8 @@ proc initProtocolState*[T](state: T, x: Peer|EthereumNode)
|
|||
proc initProtocolStates(peer: Peer, protocols: openArray[ProtocolInfo])
|
||||
{.raises: [Defect].} =
|
||||
# Initialize all the active protocol states
|
||||
newSeq(peer.protocolStates, allProtocols.len)
|
||||
newSeq(peer.protocolStates, protocolCount())
|
||||
for protocol in protocols:
|
||||
let peerStateInit = protocol.peerStateInitializer
|
||||
if peerStateInit != nil:
|
||||
peer.protocolStates[protocol.index] = peerStateInit(peer)
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
{.push raises: [Defect].}
|
||||
|
||||
import
|
||||
std/[options, sequtils],
|
||||
std/[options, sequtils, macrocache],
|
||||
stew/shims/macros, chronos, faststreams/outputs
|
||||
|
||||
type
|
||||
|
@ -76,7 +76,7 @@ type
|
|||
|
||||
# Cached properties
|
||||
nameIdent*: NimNode
|
||||
protocolInfoVar*: NimNode
|
||||
protocolInfo*: NimNode
|
||||
|
||||
# All messages
|
||||
messages*: seq[Message]
|
||||
|
@ -146,6 +146,9 @@ let
|
|||
PROTO {.compileTime.} = ident "PROTO"
|
||||
MSG {.compileTime.} = ident "MSG"
|
||||
|
||||
const
|
||||
protocolCounter = CacheCounter"protocolCounter"
|
||||
|
||||
template Opt(T): auto = newTree(nnkBracketExpr, Option, T)
|
||||
template Fut(T): auto = newTree(nnkBracketExpr, Future, T)
|
||||
|
||||
|
@ -311,7 +314,7 @@ proc init*(T: type P2PProtocol, backendFactory: BackendFactory,
|
|||
PeerStateType: verifyStateType peerState,
|
||||
NetworkStateType: verifyStateType networkState,
|
||||
nameIdent: ident(name),
|
||||
protocolInfoVar: ident(name & "Protocol"),
|
||||
protocolInfo: newCall(ident("protocolInfo"), ident(name)),
|
||||
outSendProcs: newStmtList(),
|
||||
outRecvProcs: newStmtList(),
|
||||
outProcRegistrations: newStmtList())
|
||||
|
@ -343,7 +346,7 @@ proc augmentUserHandler(p: P2PProtocol, userHandlerProc: NimNode, msgId = -1) =
|
|||
var
|
||||
getState = ident"getState"
|
||||
getNetworkState = ident"getNetworkState"
|
||||
protocolInfoVar = p.protocolInfoVar
|
||||
protocolInfo = p.protocolInfo
|
||||
protocolNameIdent = p.nameIdent
|
||||
PeerType = p.backend.PeerType
|
||||
PeerStateType = p.PeerStateType
|
||||
|
@ -370,12 +373,12 @@ proc augmentUserHandler(p: P2PProtocol, userHandlerProc: NimNode, msgId = -1) =
|
|||
if PeerStateType != nil:
|
||||
prelude.add quote do:
|
||||
template state(`peerVar`: `PeerType`): `PeerStateType` =
|
||||
cast[`PeerStateType`](`getState`(`peerVar`, `protocolInfoVar`))
|
||||
`PeerStateType`(`getState`(`peerVar`, `protocolInfo`))
|
||||
|
||||
if NetworkStateType != nil:
|
||||
prelude.add quote do:
|
||||
template networkState(`peerVar`: `PeerType`): `NetworkStateType` =
|
||||
cast[`NetworkStateType`](`getNetworkState`(`peerVar`.network, `protocolInfoVar`))
|
||||
`NetworkStateType`(`getNetworkState`(`peerVar`.network, `protocolInfo`))
|
||||
|
||||
proc addPreludeDefs*(userHandlerProc: NimNode, definitions: NimNode) =
|
||||
userHandlerProc.body[0].add definitions
|
||||
|
@ -699,7 +702,7 @@ proc useStandardBody*(sendProc: SendProc,
|
|||
newStmtList()
|
||||
else:
|
||||
logSentMsgFields(recipient,
|
||||
msg.protocol.protocolInfoVar,
|
||||
msg.protocol.protocolInfo,
|
||||
$msg.ident,
|
||||
sendProc.msgParams)
|
||||
|
||||
|
@ -895,16 +898,24 @@ proc processProtocolBody*(p: P2PProtocol, protocolBody: NimNode) =
|
|||
|
||||
proc genTypeSection*(p: P2PProtocol): NimNode =
|
||||
var
|
||||
protocolIdx = protocolCounter.value
|
||||
protocolName = p.nameIdent
|
||||
peerState = p.PeerStateType
|
||||
networkState= p.NetworkStateType
|
||||
|
||||
protocolCounter.inc
|
||||
result = newStmtList()
|
||||
result.add quote do:
|
||||
# Create a type acting as a pseudo-object representing the protocol
|
||||
# (e.g. p2p)
|
||||
type `protocolName`* = object
|
||||
|
||||
# The protocol run-time index is available as a pseudo-field
|
||||
# (e.g. `p2p.index`)
|
||||
template index*(`PROTO`: type `protocolName`): auto = `protocolIdx`
|
||||
template protocolInfo*(`PROTO`: type `protocolName`): auto =
|
||||
getProtocol(`protocolIdx`)
|
||||
|
||||
if peerState != nil:
|
||||
result.add quote do:
|
||||
template State*(`PROTO`: type `protocolName`): type = `peerState`
|
||||
|
@ -949,33 +960,29 @@ proc genCode*(p: P2PProtocol): NimNode =
|
|||
result.add p.genTypeSection()
|
||||
|
||||
let
|
||||
protocolInfoVar = p.protocolInfoVar
|
||||
protocolInfoVarObj = ident($protocolInfoVar & "Obj")
|
||||
protocolName = p.nameIdent
|
||||
protocolInit = p.backend.implementProtocolInit(p)
|
||||
|
||||
result.add quote do:
|
||||
# One global variable per protocol holds the protocol run-time data
|
||||
var `protocolInfoVarObj` = `protocolInit`
|
||||
var `protocolInfoVar` = addr `protocolInfoVarObj`
|
||||
|
||||
# The protocol run-time data is available as a pseudo-field
|
||||
# (e.g. `p2p.protocolInfo`)
|
||||
template protocolInfo*(`PROTO`: type `protocolName`): auto = `protocolInfoVar`
|
||||
protocolReg = ident($p.nameIdent & "Registration")
|
||||
regBody = newStmtList()
|
||||
|
||||
result.add p.outSendProcs,
|
||||
p.outRecvProcs,
|
||||
p.outProcRegistrations
|
||||
p.outRecvProcs
|
||||
|
||||
if p.onPeerConnected != nil: result.add p.onPeerConnected
|
||||
if p.onPeerDisconnected != nil: result.add p.onPeerDisconnected
|
||||
|
||||
result.add newCall(p.backend.setEventHandlers,
|
||||
protocolInfoVar,
|
||||
nameOrNil p.onPeerConnected,
|
||||
nameOrNil p.onPeerDisconnected)
|
||||
regBody.add newCall(p.backend.setEventHandlers,
|
||||
protocolVar,
|
||||
nameOrNil p.onPeerConnected,
|
||||
nameOrNil p.onPeerDisconnected)
|
||||
|
||||
result.add newCall(p.backend.registerProtocol, protocolInfoVar)
|
||||
regBody.add p.outProcRegistrations
|
||||
regBody.add newCall(p.backend.registerProtocol, protocolVar)
|
||||
|
||||
result.add quote do:
|
||||
proc `protocolReg`() {.raises: [RlpError, Defect].} =
|
||||
let `protocolVar` = `protocolInit`
|
||||
`regBody`
|
||||
`protocolReg`()
|
||||
|
||||
macro emitForSingleBackend(
|
||||
name: static[string],
|
||||
|
|
|
@ -93,7 +93,10 @@ type
|
|||
## Quasy-private types. Use at your own risk.
|
||||
##
|
||||
|
||||
ProtocolInfoObj* = object
|
||||
ProtocolManager* = ref object
|
||||
protocols*: seq[ProtocolInfo]
|
||||
|
||||
ProtocolInfo* = ref object
|
||||
name*: string
|
||||
version*: int
|
||||
messages*: seq[MessageInfo]
|
||||
|
@ -106,9 +109,7 @@ type
|
|||
handshake*: HandshakeStep
|
||||
disconnectHandler*: DisconnectionHandler
|
||||
|
||||
ProtocolInfo* = ptr ProtocolInfoObj
|
||||
|
||||
MessageInfo* = object
|
||||
MessageInfo* = ref object
|
||||
id*: int
|
||||
name*: string
|
||||
|
||||
|
@ -132,7 +133,7 @@ type
|
|||
# `messages` holds a mapping from valid message IDs to their handler procs.
|
||||
#
|
||||
protocolOffsets*: seq[int]
|
||||
messages*: seq[ptr MessageInfo]
|
||||
messages*: seq[MessageInfo]
|
||||
activeProtocols*: seq[ProtocolInfo]
|
||||
|
||||
##
|
||||
|
|
|
@ -192,9 +192,6 @@ proc handshakeImpl[T](peer: Peer,
|
|||
else:
|
||||
return responseFut.read
|
||||
|
||||
var gDevp2pInfo: ProtocolInfo
|
||||
template devp2pInfo: auto = {.gcsafe.}: gDevp2pInfo
|
||||
|
||||
# Dispatcher
|
||||
#
|
||||
|
||||
|
@ -220,7 +217,7 @@ proc getDispatcher(node: EthereumNode,
|
|||
# We should be able to find an existing dispatcher without allocating a new one
|
||||
|
||||
new result
|
||||
newSeq(result.protocolOffsets, allProtocols.len)
|
||||
newSeq(result.protocolOffsets, protocolCount())
|
||||
result.protocolOffsets.fill -1
|
||||
|
||||
var nextUserMsgId = 0x10
|
||||
|
@ -237,9 +234,9 @@ proc getDispatcher(node: EthereumNode,
|
|||
|
||||
template copyTo(src, dest; index: int) =
|
||||
for i in 0 ..< src.len:
|
||||
dest[index + i] = addr src[i]
|
||||
dest[index + i] = src[i]
|
||||
|
||||
result.messages = newSeq[ptr MessageInfo](nextUserMsgId)
|
||||
result.messages = newSeq[MessageInfo](nextUserMsgId)
|
||||
devp2pInfo.messages.copyTo(result.messages, 0)
|
||||
|
||||
for localProtocol in node.protocols:
|
||||
|
@ -262,30 +259,35 @@ proc getMsgName*(peer: Peer, msgId: int): string =
|
|||
of 3: "pong"
|
||||
else: $msgId
|
||||
|
||||
proc getMsgMetadata*(peer: Peer, msgId: int): (ProtocolInfo, ptr MessageInfo) =
|
||||
proc getMsgMetadata*(peer: Peer, msgId: int): (ProtocolInfo, MessageInfo) =
|
||||
doAssert msgId >= 0
|
||||
|
||||
if msgId <= devp2pInfo.messages[^1].id:
|
||||
return (devp2pInfo, addr devp2pInfo.messages[msgId])
|
||||
let dpInfo = devp2pInfo()
|
||||
if msgId <= dpInfo.messages[^1].id:
|
||||
return (dpInfo, dpInfo.messages[msgId])
|
||||
|
||||
if msgId < peer.dispatcher.messages.len:
|
||||
for i in 0 ..< allProtocols.len:
|
||||
let numProtocol = protocolCount()
|
||||
for i in 0 ..< numProtocol:
|
||||
let protocol = getProtocol(i)
|
||||
let offset = peer.dispatcher.protocolOffsets[i]
|
||||
if offset != -1 and
|
||||
offset + allProtocols[i].messages[^1].id >= msgId:
|
||||
return (allProtocols[i], peer.dispatcher.messages[msgId])
|
||||
offset + protocol.messages[^1].id >= msgId:
|
||||
return (protocol, peer.dispatcher.messages[msgId])
|
||||
|
||||
# Protocol info objects
|
||||
#
|
||||
|
||||
proc initProtocol(name: string, version: int,
|
||||
peerInit: PeerStateInitializer,
|
||||
networkInit: NetworkStateInitializer): ProtocolInfoObj =
|
||||
result.name = name
|
||||
result.version = version
|
||||
result.messages = @[]
|
||||
result.peerStateInitializer = peerInit
|
||||
result.networkStateInitializer = networkInit
|
||||
networkInit: NetworkStateInitializer): ProtocolInfo =
|
||||
ProtocolInfo(
|
||||
name : name,
|
||||
version : version,
|
||||
messages: @[],
|
||||
peerStateInitializer: peerInit,
|
||||
networkStateInitializer: networkInit
|
||||
)
|
||||
|
||||
proc setEventHandlers(p: ProtocolInfo,
|
||||
handshake: HandshakeStep,
|
||||
|
@ -321,16 +323,6 @@ proc registerMsg(protocol: ProtocolInfo,
|
|||
requestResolver: requestResolver,
|
||||
nextMsgResolver: nextMsgResolver)
|
||||
|
||||
proc registerProtocol(protocol: ProtocolInfo) =
|
||||
# TODO: This can be done at compile-time in the future
|
||||
if protocol.name != "p2p":
|
||||
let pos = lowerBound(gProtocols, protocol)
|
||||
gProtocols.insert(protocol, pos)
|
||||
for i in 0 ..< gProtocols.len:
|
||||
gProtocols[i].index = i
|
||||
else:
|
||||
gDevp2pInfo = protocol
|
||||
|
||||
# Message composition and encryption
|
||||
#
|
||||
|
||||
|
@ -973,7 +965,7 @@ proc p2pProtocolBackendImpl*(protocol: P2PProtocol): Backend =
|
|||
quote: return `sendCall`
|
||||
|
||||
let perPeerMsgIdValue = if isSubprotocol:
|
||||
newCall(perPeerMsgIdImpl, peerVar, protocol.protocolInfoVar, newLit(msgId))
|
||||
newCall(perPeerMsgIdImpl, peerVar, protocol.protocolInfo, newLit(msgId))
|
||||
else:
|
||||
newLit(msgId)
|
||||
|
||||
|
@ -1009,7 +1001,7 @@ proc p2pProtocolBackendImpl*(protocol: P2PProtocol): Backend =
|
|||
|
||||
protocol.outProcRegistrations.add(
|
||||
newCall(registerMsg,
|
||||
protocol.protocolInfoVar,
|
||||
protocolVar,
|
||||
newLit(msgId),
|
||||
newLit(msgName),
|
||||
thunkName,
|
||||
|
@ -1063,7 +1055,7 @@ proc removePeer(network: EthereumNode, peer: Peer) =
|
|||
|
||||
proc callDisconnectHandlers(peer: Peer, reason: DisconnectionReason):
|
||||
Future[void] {.async.} =
|
||||
var futures = newSeqOfCap[Future[void]](allProtocols.len)
|
||||
var futures = newSeqOfCap[Future[void]](protocolCount())
|
||||
|
||||
for protocol in peer.dispatcher.activeProtocols:
|
||||
if protocol.disconnectHandler != nil:
|
||||
|
@ -1144,7 +1136,7 @@ proc postHelloSteps(peer: Peer, h: DevP2P.hello) {.async.} =
|
|||
# chance to send any initial packages they might require over
|
||||
# the network and to yield on their `nextMsg` waits.
|
||||
#
|
||||
var subProtocolsHandshakes = newSeqOfCap[Future[void]](allProtocols.len)
|
||||
var subProtocolsHandshakes = newSeqOfCap[Future[void]](protocolCount())
|
||||
for protocol in peer.dispatcher.activeProtocols:
|
||||
if protocol.handshake != nil:
|
||||
subProtocolsHandshakes.add((protocol.handshake)(peer))
|
||||
|
|
|
@ -6,7 +6,7 @@ import
|
|||
# real eth protocol implementation is in nimbus-eth1 repo
|
||||
|
||||
type
|
||||
PeerState = ref object
|
||||
PeerState = ref object of RootRef
|
||||
initialized*: bool
|
||||
|
||||
p2pProtocol eth(version = 63,
|
||||
|
|
|
@ -16,9 +16,12 @@ import
|
|||
./p2p_test_helper
|
||||
|
||||
type
|
||||
network = ref object
|
||||
network = ref object of RootRef
|
||||
count*: int
|
||||
|
||||
PeerState = ref object of RootRef
|
||||
status*: string
|
||||
|
||||
p2pProtocol abc(version = 1,
|
||||
rlpxName = "abc",
|
||||
networkState = network):
|
||||
|
@ -33,15 +36,18 @@ p2pProtocol abc(version = 1,
|
|||
|
||||
p2pProtocol xyz(version = 1,
|
||||
rlpxName = "xyz",
|
||||
networkState = network):
|
||||
networkState = network,
|
||||
peerState = PeerState):
|
||||
|
||||
onPeerConnected do (peer: Peer):
|
||||
peer.networkState.count += 1
|
||||
peer.state.status = "connected"
|
||||
|
||||
onPeerDisconnected do (peer: Peer, reason: DisconnectionReason) {.gcsafe.}:
|
||||
peer.networkState.count -= 1
|
||||
if true:
|
||||
raise newException(CatchableError, "Fake xyz exception")
|
||||
peer.state.status = "disconnected"
|
||||
|
||||
p2pProtocol hah(version = 1,
|
||||
rlpxName = "hah",
|
||||
|
@ -67,6 +73,7 @@ suite "Testing protocol handlers":
|
|||
let peer = await node1.rlpxConnect(newNode(node2.toENode()))
|
||||
check:
|
||||
peer.isNil == false
|
||||
peer.state(xyz).status == "connected"
|
||||
|
||||
await peer.disconnect(SubprotocolReason, true)
|
||||
check:
|
||||
|
@ -74,6 +81,7 @@ suite "Testing protocol handlers":
|
|||
# handlers, each handler still ran
|
||||
node1.protocolState(abc).count == 0
|
||||
node1.protocolState(xyz).count == 0
|
||||
peer.state(xyz).status == "connected"
|
||||
|
||||
asyncTest "Failing connection handler":
|
||||
let rng = newRng()
|
||||
|
|
Loading…
Reference in New Issue