refactor p2pProtocol internals

Nim devel brach(1.7.1) introduce gc=orc as default mode.
Because the p2p protocol using unsafe pointer operations
for it's ProtocolInfo and using global variables scattered
around, the orc mistakenly(or maybe correctly) crash the protocol.
This commit is contained in:
jangko 2022-10-13 09:50:49 +07:00
parent d238693571
commit e1bdf1741a
No known key found for this signature in database
GPG Key ID: 31702AE10541E6B9
7 changed files with 106 additions and 75 deletions

View File

@ -106,7 +106,7 @@ proc newEthereumNode*(
result.protocolVersion = if useCompression: devp2pSnappyVersion result.protocolVersion = if useCompression: devp2pSnappyVersion
else: devp2pVersion else: devp2pVersion
result.protocolStates.newSeq allProtocols.len result.protocolStates.newSeq protocolCount()
result.peerPool = newPeerPool( result.peerPool = newPeerPool(
result, networkId, keys, nil, clientId, minPeers = minPeers) result, networkId, keys, nil, clientId, minPeers = minPeers)
@ -114,8 +114,8 @@ proc newEthereumNode*(
result.peerPool.discovery = result.discovery result.peerPool.discovery = result.discovery
if addAllCapabilities: if addAllCapabilities:
for p in allProtocols: for cap in protocols():
result.addCapability(p) result.addCapability(cap)
proc processIncoming(server: StreamServer, proc processIncoming(server: StreamServer,
remote: StreamTransport): Future[void] {.async, gcsafe.} = remote: StreamTransport): Future[void] {.async, gcsafe.} =

View File

@ -1,9 +1,33 @@
var let protocolManager = ProtocolManager()
gProtocols: seq[ProtocolInfo]
# The variables above are immutable RTTI information. We need to tell # The variables above are immutable RTTI information. We need to tell
# Nim to not consider them GcSafe violations: # Nim to not consider them GcSafe violations:
template allProtocols*: auto = {.gcsafe.}: gProtocols
proc registerProtocol*(proto: ProtocolInfo) {.gcsafe.} =
{.gcsafe.}:
proto.index = protocolManager.protocols.len
if proto.name == "p2p":
doAssert(proto.index == 0)
protocolManager.protocols.add proto
proc protocolCount*(): int {.gcsafe.} =
{.gcsafe.}:
protocolManager.protocols.len
proc getProtocol*(index: int): ProtocolInfo {.gcsafe.} =
{.gcsafe.}:
protocolManager.protocols[index]
iterator protocols*(): ProtocolInfo {.gcsafe.} =
{.gcsafe.}:
for x in protocolManager.protocols:
yield x
template getProtocol*(Protocol: type): ProtocolInfo =
getProtocol(Protocol.index)
template devp2pInfo*(): ProtocolInfo =
getProtocol(0)
proc getState*(peer: Peer, proto: ProtocolInfo): RootRef = proc getState*(peer: Peer, proto: ProtocolInfo): RootRef =
peer.protocolStates[proto.index] peer.protocolStates[proto.index]
@ -35,9 +59,8 @@ proc initProtocolState*[T](state: T, x: Peer|EthereumNode)
proc initProtocolStates(peer: Peer, protocols: openArray[ProtocolInfo]) proc initProtocolStates(peer: Peer, protocols: openArray[ProtocolInfo])
{.raises: [Defect].} = {.raises: [Defect].} =
# Initialize all the active protocol states # Initialize all the active protocol states
newSeq(peer.protocolStates, allProtocols.len) newSeq(peer.protocolStates, protocolCount())
for protocol in protocols: for protocol in protocols:
let peerStateInit = protocol.peerStateInitializer let peerStateInit = protocol.peerStateInitializer
if peerStateInit != nil: if peerStateInit != nil:
peer.protocolStates[protocol.index] = peerStateInit(peer) peer.protocolStates[protocol.index] = peerStateInit(peer)

View File

@ -1,7 +1,7 @@
{.push raises: [Defect].} {.push raises: [Defect].}
import import
std/[options, sequtils], std/[options, sequtils, macrocache],
stew/shims/macros, chronos, faststreams/outputs stew/shims/macros, chronos, faststreams/outputs
type type
@ -76,7 +76,7 @@ type
# Cached properties # Cached properties
nameIdent*: NimNode nameIdent*: NimNode
protocolInfoVar*: NimNode protocolInfo*: NimNode
# All messages # All messages
messages*: seq[Message] messages*: seq[Message]
@ -146,6 +146,9 @@ let
PROTO {.compileTime.} = ident "PROTO" PROTO {.compileTime.} = ident "PROTO"
MSG {.compileTime.} = ident "MSG" MSG {.compileTime.} = ident "MSG"
const
protocolCounter = CacheCounter"protocolCounter"
template Opt(T): auto = newTree(nnkBracketExpr, Option, T) template Opt(T): auto = newTree(nnkBracketExpr, Option, T)
template Fut(T): auto = newTree(nnkBracketExpr, Future, T) template Fut(T): auto = newTree(nnkBracketExpr, Future, T)
@ -253,7 +256,7 @@ proc refreshParam(n: NimNode): NimNode =
result = copyNimTree(n) result = copyNimTree(n)
if n.kind == nnkIdentDefs: if n.kind == nnkIdentDefs:
for i in 0..<n.len-2: for i in 0..<n.len-2:
if n[i].kind == nnkSym: if n[i].kind == nnkSym:
result[i] = genSym(symKind(n[i]), $n[i]) result[i] = genSym(symKind(n[i]), $n[i])
iterator typedInputParams(procDef: NimNode, skip = 0): (NimNode, NimNode) = iterator typedInputParams(procDef: NimNode, skip = 0): (NimNode, NimNode) =
@ -311,7 +314,7 @@ proc init*(T: type P2PProtocol, backendFactory: BackendFactory,
PeerStateType: verifyStateType peerState, PeerStateType: verifyStateType peerState,
NetworkStateType: verifyStateType networkState, NetworkStateType: verifyStateType networkState,
nameIdent: ident(name), nameIdent: ident(name),
protocolInfoVar: ident(name & "Protocol"), protocolInfo: newCall(ident("protocolInfo"), ident(name)),
outSendProcs: newStmtList(), outSendProcs: newStmtList(),
outRecvProcs: newStmtList(), outRecvProcs: newStmtList(),
outProcRegistrations: newStmtList()) outProcRegistrations: newStmtList())
@ -343,7 +346,7 @@ proc augmentUserHandler(p: P2PProtocol, userHandlerProc: NimNode, msgId = -1) =
var var
getState = ident"getState" getState = ident"getState"
getNetworkState = ident"getNetworkState" getNetworkState = ident"getNetworkState"
protocolInfoVar = p.protocolInfoVar protocolInfo = p.protocolInfo
protocolNameIdent = p.nameIdent protocolNameIdent = p.nameIdent
PeerType = p.backend.PeerType PeerType = p.backend.PeerType
PeerStateType = p.PeerStateType PeerStateType = p.PeerStateType
@ -370,12 +373,12 @@ proc augmentUserHandler(p: P2PProtocol, userHandlerProc: NimNode, msgId = -1) =
if PeerStateType != nil: if PeerStateType != nil:
prelude.add quote do: prelude.add quote do:
template state(`peerVar`: `PeerType`): `PeerStateType` = template state(`peerVar`: `PeerType`): `PeerStateType` =
cast[`PeerStateType`](`getState`(`peerVar`, `protocolInfoVar`)) `PeerStateType`(`getState`(`peerVar`, `protocolInfo`))
if NetworkStateType != nil: if NetworkStateType != nil:
prelude.add quote do: prelude.add quote do:
template networkState(`peerVar`: `PeerType`): `NetworkStateType` = template networkState(`peerVar`: `PeerType`): `NetworkStateType` =
cast[`NetworkStateType`](`getNetworkState`(`peerVar`.network, `protocolInfoVar`)) `NetworkStateType`(`getNetworkState`(`peerVar`.network, `protocolInfo`))
proc addPreludeDefs*(userHandlerProc: NimNode, definitions: NimNode) = proc addPreludeDefs*(userHandlerProc: NimNode, definitions: NimNode) =
userHandlerProc.body[0].add definitions userHandlerProc.body[0].add definitions
@ -699,7 +702,7 @@ proc useStandardBody*(sendProc: SendProc,
newStmtList() newStmtList()
else: else:
logSentMsgFields(recipient, logSentMsgFields(recipient,
msg.protocol.protocolInfoVar, msg.protocol.protocolInfo,
$msg.ident, $msg.ident,
sendProc.msgParams) sendProc.msgParams)
@ -895,16 +898,24 @@ proc processProtocolBody*(p: P2PProtocol, protocolBody: NimNode) =
proc genTypeSection*(p: P2PProtocol): NimNode = proc genTypeSection*(p: P2PProtocol): NimNode =
var var
protocolIdx = protocolCounter.value
protocolName = p.nameIdent protocolName = p.nameIdent
peerState = p.PeerStateType peerState = p.PeerStateType
networkState= p.NetworkStateType networkState= p.NetworkStateType
protocolCounter.inc
result = newStmtList() result = newStmtList()
result.add quote do: result.add quote do:
# Create a type acting as a pseudo-object representing the protocol # Create a type acting as a pseudo-object representing the protocol
# (e.g. p2p) # (e.g. p2p)
type `protocolName`* = object type `protocolName`* = object
# The protocol run-time index is available as a pseudo-field
# (e.g. `p2p.index`)
template index*(`PROTO`: type `protocolName`): auto = `protocolIdx`
template protocolInfo*(`PROTO`: type `protocolName`): auto =
getProtocol(`protocolIdx`)
if peerState != nil: if peerState != nil:
result.add quote do: result.add quote do:
template State*(`PROTO`: type `protocolName`): type = `peerState` template State*(`PROTO`: type `protocolName`): type = `peerState`
@ -949,33 +960,29 @@ proc genCode*(p: P2PProtocol): NimNode =
result.add p.genTypeSection() result.add p.genTypeSection()
let let
protocolInfoVar = p.protocolInfoVar
protocolInfoVarObj = ident($protocolInfoVar & "Obj")
protocolName = p.nameIdent
protocolInit = p.backend.implementProtocolInit(p) protocolInit = p.backend.implementProtocolInit(p)
protocolReg = ident($p.nameIdent & "Registration")
result.add quote do: regBody = newStmtList()
# One global variable per protocol holds the protocol run-time data
var `protocolInfoVarObj` = `protocolInit`
var `protocolInfoVar` = addr `protocolInfoVarObj`
# The protocol run-time data is available as a pseudo-field
# (e.g. `p2p.protocolInfo`)
template protocolInfo*(`PROTO`: type `protocolName`): auto = `protocolInfoVar`
result.add p.outSendProcs, result.add p.outSendProcs,
p.outRecvProcs, p.outRecvProcs
p.outProcRegistrations
if p.onPeerConnected != nil: result.add p.onPeerConnected if p.onPeerConnected != nil: result.add p.onPeerConnected
if p.onPeerDisconnected != nil: result.add p.onPeerDisconnected if p.onPeerDisconnected != nil: result.add p.onPeerDisconnected
result.add newCall(p.backend.setEventHandlers, regBody.add newCall(p.backend.setEventHandlers,
protocolInfoVar, protocolVar,
nameOrNil p.onPeerConnected, nameOrNil p.onPeerConnected,
nameOrNil p.onPeerDisconnected) nameOrNil p.onPeerDisconnected)
result.add newCall(p.backend.registerProtocol, protocolInfoVar) regBody.add p.outProcRegistrations
regBody.add newCall(p.backend.registerProtocol, protocolVar)
result.add quote do:
proc `protocolReg`() {.raises: [RlpError, Defect].} =
let `protocolVar` = `protocolInit`
`regBody`
`protocolReg`()
macro emitForSingleBackend( macro emitForSingleBackend(
name: static[string], name: static[string],

View File

@ -93,7 +93,10 @@ type
## Quasy-private types. Use at your own risk. ## Quasy-private types. Use at your own risk.
## ##
ProtocolInfoObj* = object ProtocolManager* = ref object
protocols*: seq[ProtocolInfo]
ProtocolInfo* = ref object
name*: string name*: string
version*: int version*: int
messages*: seq[MessageInfo] messages*: seq[MessageInfo]
@ -106,9 +109,7 @@ type
handshake*: HandshakeStep handshake*: HandshakeStep
disconnectHandler*: DisconnectionHandler disconnectHandler*: DisconnectionHandler
ProtocolInfo* = ptr ProtocolInfoObj MessageInfo* = ref object
MessageInfo* = object
id*: int id*: int
name*: string name*: string
@ -132,7 +133,7 @@ type
# `messages` holds a mapping from valid message IDs to their handler procs. # `messages` holds a mapping from valid message IDs to their handler procs.
# #
protocolOffsets*: seq[int] protocolOffsets*: seq[int]
messages*: seq[ptr MessageInfo] messages*: seq[MessageInfo]
activeProtocols*: seq[ProtocolInfo] activeProtocols*: seq[ProtocolInfo]
## ##

View File

@ -192,9 +192,6 @@ proc handshakeImpl[T](peer: Peer,
else: else:
return responseFut.read return responseFut.read
var gDevp2pInfo: ProtocolInfo
template devp2pInfo: auto = {.gcsafe.}: gDevp2pInfo
# Dispatcher # Dispatcher
# #
@ -220,7 +217,7 @@ proc getDispatcher(node: EthereumNode,
# We should be able to find an existing dispatcher without allocating a new one # We should be able to find an existing dispatcher without allocating a new one
new result new result
newSeq(result.protocolOffsets, allProtocols.len) newSeq(result.protocolOffsets, protocolCount())
result.protocolOffsets.fill -1 result.protocolOffsets.fill -1
var nextUserMsgId = 0x10 var nextUserMsgId = 0x10
@ -237,9 +234,9 @@ proc getDispatcher(node: EthereumNode,
template copyTo(src, dest; index: int) = template copyTo(src, dest; index: int) =
for i in 0 ..< src.len: for i in 0 ..< src.len:
dest[index + i] = addr src[i] dest[index + i] = src[i]
result.messages = newSeq[ptr MessageInfo](nextUserMsgId) result.messages = newSeq[MessageInfo](nextUserMsgId)
devp2pInfo.messages.copyTo(result.messages, 0) devp2pInfo.messages.copyTo(result.messages, 0)
for localProtocol in node.protocols: for localProtocol in node.protocols:
@ -262,30 +259,35 @@ proc getMsgName*(peer: Peer, msgId: int): string =
of 3: "pong" of 3: "pong"
else: $msgId else: $msgId
proc getMsgMetadata*(peer: Peer, msgId: int): (ProtocolInfo, ptr MessageInfo) = proc getMsgMetadata*(peer: Peer, msgId: int): (ProtocolInfo, MessageInfo) =
doAssert msgId >= 0 doAssert msgId >= 0
if msgId <= devp2pInfo.messages[^1].id: let dpInfo = devp2pInfo()
return (devp2pInfo, addr devp2pInfo.messages[msgId]) if msgId <= dpInfo.messages[^1].id:
return (dpInfo, dpInfo.messages[msgId])
if msgId < peer.dispatcher.messages.len: if msgId < peer.dispatcher.messages.len:
for i in 0 ..< allProtocols.len: let numProtocol = protocolCount()
for i in 0 ..< numProtocol:
let protocol = getProtocol(i)
let offset = peer.dispatcher.protocolOffsets[i] let offset = peer.dispatcher.protocolOffsets[i]
if offset != -1 and if offset != -1 and
offset + allProtocols[i].messages[^1].id >= msgId: offset + protocol.messages[^1].id >= msgId:
return (allProtocols[i], peer.dispatcher.messages[msgId]) return (protocol, peer.dispatcher.messages[msgId])
# Protocol info objects # Protocol info objects
# #
proc initProtocol(name: string, version: int, proc initProtocol(name: string, version: int,
peerInit: PeerStateInitializer, peerInit: PeerStateInitializer,
networkInit: NetworkStateInitializer): ProtocolInfoObj = networkInit: NetworkStateInitializer): ProtocolInfo =
result.name = name ProtocolInfo(
result.version = version name : name,
result.messages = @[] version : version,
result.peerStateInitializer = peerInit messages: @[],
result.networkStateInitializer = networkInit peerStateInitializer: peerInit,
networkStateInitializer: networkInit
)
proc setEventHandlers(p: ProtocolInfo, proc setEventHandlers(p: ProtocolInfo,
handshake: HandshakeStep, handshake: HandshakeStep,
@ -321,16 +323,6 @@ proc registerMsg(protocol: ProtocolInfo,
requestResolver: requestResolver, requestResolver: requestResolver,
nextMsgResolver: nextMsgResolver) nextMsgResolver: nextMsgResolver)
proc registerProtocol(protocol: ProtocolInfo) =
# TODO: This can be done at compile-time in the future
if protocol.name != "p2p":
let pos = lowerBound(gProtocols, protocol)
gProtocols.insert(protocol, pos)
for i in 0 ..< gProtocols.len:
gProtocols[i].index = i
else:
gDevp2pInfo = protocol
# Message composition and encryption # Message composition and encryption
# #
@ -973,7 +965,7 @@ proc p2pProtocolBackendImpl*(protocol: P2PProtocol): Backend =
quote: return `sendCall` quote: return `sendCall`
let perPeerMsgIdValue = if isSubprotocol: let perPeerMsgIdValue = if isSubprotocol:
newCall(perPeerMsgIdImpl, peerVar, protocol.protocolInfoVar, newLit(msgId)) newCall(perPeerMsgIdImpl, peerVar, protocol.protocolInfo, newLit(msgId))
else: else:
newLit(msgId) newLit(msgId)
@ -1009,7 +1001,7 @@ proc p2pProtocolBackendImpl*(protocol: P2PProtocol): Backend =
protocol.outProcRegistrations.add( protocol.outProcRegistrations.add(
newCall(registerMsg, newCall(registerMsg,
protocol.protocolInfoVar, protocolVar,
newLit(msgId), newLit(msgId),
newLit(msgName), newLit(msgName),
thunkName, thunkName,
@ -1063,7 +1055,7 @@ proc removePeer(network: EthereumNode, peer: Peer) =
proc callDisconnectHandlers(peer: Peer, reason: DisconnectionReason): proc callDisconnectHandlers(peer: Peer, reason: DisconnectionReason):
Future[void] {.async.} = Future[void] {.async.} =
var futures = newSeqOfCap[Future[void]](allProtocols.len) var futures = newSeqOfCap[Future[void]](protocolCount())
for protocol in peer.dispatcher.activeProtocols: for protocol in peer.dispatcher.activeProtocols:
if protocol.disconnectHandler != nil: if protocol.disconnectHandler != nil:
@ -1144,7 +1136,7 @@ proc postHelloSteps(peer: Peer, h: DevP2P.hello) {.async.} =
# chance to send any initial packages they might require over # chance to send any initial packages they might require over
# the network and to yield on their `nextMsg` waits. # the network and to yield on their `nextMsg` waits.
# #
var subProtocolsHandshakes = newSeqOfCap[Future[void]](allProtocols.len) var subProtocolsHandshakes = newSeqOfCap[Future[void]](protocolCount())
for protocol in peer.dispatcher.activeProtocols: for protocol in peer.dispatcher.activeProtocols:
if protocol.handshake != nil: if protocol.handshake != nil:
subProtocolsHandshakes.add((protocol.handshake)(peer)) subProtocolsHandshakes.add((protocol.handshake)(peer))

View File

@ -6,7 +6,7 @@ import
# real eth protocol implementation is in nimbus-eth1 repo # real eth protocol implementation is in nimbus-eth1 repo
type type
PeerState = ref object PeerState = ref object of RootRef
initialized*: bool initialized*: bool
p2pProtocol eth(version = 63, p2pProtocol eth(version = 63,

View File

@ -16,9 +16,12 @@ import
./p2p_test_helper ./p2p_test_helper
type type
network = ref object network = ref object of RootRef
count*: int count*: int
PeerState = ref object of RootRef
status*: string
p2pProtocol abc(version = 1, p2pProtocol abc(version = 1,
rlpxName = "abc", rlpxName = "abc",
networkState = network): networkState = network):
@ -33,15 +36,18 @@ p2pProtocol abc(version = 1,
p2pProtocol xyz(version = 1, p2pProtocol xyz(version = 1,
rlpxName = "xyz", rlpxName = "xyz",
networkState = network): networkState = network,
peerState = PeerState):
onPeerConnected do (peer: Peer): onPeerConnected do (peer: Peer):
peer.networkState.count += 1 peer.networkState.count += 1
peer.state.status = "connected"
onPeerDisconnected do (peer: Peer, reason: DisconnectionReason) {.gcsafe.}: onPeerDisconnected do (peer: Peer, reason: DisconnectionReason) {.gcsafe.}:
peer.networkState.count -= 1 peer.networkState.count -= 1
if true: if true:
raise newException(CatchableError, "Fake xyz exception") raise newException(CatchableError, "Fake xyz exception")
peer.state.status = "disconnected"
p2pProtocol hah(version = 1, p2pProtocol hah(version = 1,
rlpxName = "hah", rlpxName = "hah",
@ -67,6 +73,7 @@ suite "Testing protocol handlers":
let peer = await node1.rlpxConnect(newNode(node2.toENode())) let peer = await node1.rlpxConnect(newNode(node2.toENode()))
check: check:
peer.isNil == false peer.isNil == false
peer.state(xyz).status == "connected"
await peer.disconnect(SubprotocolReason, true) await peer.disconnect(SubprotocolReason, true)
check: check:
@ -74,6 +81,7 @@ suite "Testing protocol handlers":
# handlers, each handler still ran # handlers, each handler still ran
node1.protocolState(abc).count == 0 node1.protocolState(abc).count == 0
node1.protocolState(xyz).count == 0 node1.protocolState(xyz).count == 0
peer.state(xyz).status == "connected"
asyncTest "Failing connection handler": asyncTest "Failing connection handler":
let rng = newRng() let rng = newRng()