refactor p2pProtocol internals

Nim devel brach(1.7.1) introduce gc=orc as default mode.
Because the p2p protocol using unsafe pointer operations
for it's ProtocolInfo and using global variables scattered
around, the orc mistakenly(or maybe correctly) crash the protocol.
This commit is contained in:
jangko 2022-10-13 09:50:49 +07:00
parent d238693571
commit e1bdf1741a
No known key found for this signature in database
GPG Key ID: 31702AE10541E6B9
7 changed files with 106 additions and 75 deletions

View File

@ -106,7 +106,7 @@ proc newEthereumNode*(
result.protocolVersion = if useCompression: devp2pSnappyVersion
else: devp2pVersion
result.protocolStates.newSeq allProtocols.len
result.protocolStates.newSeq protocolCount()
result.peerPool = newPeerPool(
result, networkId, keys, nil, clientId, minPeers = minPeers)
@ -114,8 +114,8 @@ proc newEthereumNode*(
result.peerPool.discovery = result.discovery
if addAllCapabilities:
for p in allProtocols:
result.addCapability(p)
for cap in protocols():
result.addCapability(cap)
proc processIncoming(server: StreamServer,
remote: StreamTransport): Future[void] {.async, gcsafe.} =

View File

@ -1,9 +1,33 @@
var
gProtocols: seq[ProtocolInfo]
let protocolManager = ProtocolManager()
# The variables above are immutable RTTI information. We need to tell
# Nim to not consider them GcSafe violations:
template allProtocols*: auto = {.gcsafe.}: gProtocols
proc registerProtocol*(proto: ProtocolInfo) {.gcsafe.} =
{.gcsafe.}:
proto.index = protocolManager.protocols.len
if proto.name == "p2p":
doAssert(proto.index == 0)
protocolManager.protocols.add proto
proc protocolCount*(): int {.gcsafe.} =
{.gcsafe.}:
protocolManager.protocols.len
proc getProtocol*(index: int): ProtocolInfo {.gcsafe.} =
{.gcsafe.}:
protocolManager.protocols[index]
iterator protocols*(): ProtocolInfo {.gcsafe.} =
{.gcsafe.}:
for x in protocolManager.protocols:
yield x
template getProtocol*(Protocol: type): ProtocolInfo =
getProtocol(Protocol.index)
template devp2pInfo*(): ProtocolInfo =
getProtocol(0)
proc getState*(peer: Peer, proto: ProtocolInfo): RootRef =
peer.protocolStates[proto.index]
@ -35,9 +59,8 @@ proc initProtocolState*[T](state: T, x: Peer|EthereumNode)
proc initProtocolStates(peer: Peer, protocols: openArray[ProtocolInfo])
{.raises: [Defect].} =
# Initialize all the active protocol states
newSeq(peer.protocolStates, allProtocols.len)
newSeq(peer.protocolStates, protocolCount())
for protocol in protocols:
let peerStateInit = protocol.peerStateInitializer
if peerStateInit != nil:
peer.protocolStates[protocol.index] = peerStateInit(peer)

View File

@ -1,7 +1,7 @@
{.push raises: [Defect].}
import
std/[options, sequtils],
std/[options, sequtils, macrocache],
stew/shims/macros, chronos, faststreams/outputs
type
@ -76,7 +76,7 @@ type
# Cached properties
nameIdent*: NimNode
protocolInfoVar*: NimNode
protocolInfo*: NimNode
# All messages
messages*: seq[Message]
@ -146,6 +146,9 @@ let
PROTO {.compileTime.} = ident "PROTO"
MSG {.compileTime.} = ident "MSG"
const
protocolCounter = CacheCounter"protocolCounter"
template Opt(T): auto = newTree(nnkBracketExpr, Option, T)
template Fut(T): auto = newTree(nnkBracketExpr, Future, T)
@ -311,7 +314,7 @@ proc init*(T: type P2PProtocol, backendFactory: BackendFactory,
PeerStateType: verifyStateType peerState,
NetworkStateType: verifyStateType networkState,
nameIdent: ident(name),
protocolInfoVar: ident(name & "Protocol"),
protocolInfo: newCall(ident("protocolInfo"), ident(name)),
outSendProcs: newStmtList(),
outRecvProcs: newStmtList(),
outProcRegistrations: newStmtList())
@ -343,7 +346,7 @@ proc augmentUserHandler(p: P2PProtocol, userHandlerProc: NimNode, msgId = -1) =
var
getState = ident"getState"
getNetworkState = ident"getNetworkState"
protocolInfoVar = p.protocolInfoVar
protocolInfo = p.protocolInfo
protocolNameIdent = p.nameIdent
PeerType = p.backend.PeerType
PeerStateType = p.PeerStateType
@ -370,12 +373,12 @@ proc augmentUserHandler(p: P2PProtocol, userHandlerProc: NimNode, msgId = -1) =
if PeerStateType != nil:
prelude.add quote do:
template state(`peerVar`: `PeerType`): `PeerStateType` =
cast[`PeerStateType`](`getState`(`peerVar`, `protocolInfoVar`))
`PeerStateType`(`getState`(`peerVar`, `protocolInfo`))
if NetworkStateType != nil:
prelude.add quote do:
template networkState(`peerVar`: `PeerType`): `NetworkStateType` =
cast[`NetworkStateType`](`getNetworkState`(`peerVar`.network, `protocolInfoVar`))
`NetworkStateType`(`getNetworkState`(`peerVar`.network, `protocolInfo`))
proc addPreludeDefs*(userHandlerProc: NimNode, definitions: NimNode) =
userHandlerProc.body[0].add definitions
@ -699,7 +702,7 @@ proc useStandardBody*(sendProc: SendProc,
newStmtList()
else:
logSentMsgFields(recipient,
msg.protocol.protocolInfoVar,
msg.protocol.protocolInfo,
$msg.ident,
sendProc.msgParams)
@ -895,16 +898,24 @@ proc processProtocolBody*(p: P2PProtocol, protocolBody: NimNode) =
proc genTypeSection*(p: P2PProtocol): NimNode =
var
protocolIdx = protocolCounter.value
protocolName = p.nameIdent
peerState = p.PeerStateType
networkState= p.NetworkStateType
protocolCounter.inc
result = newStmtList()
result.add quote do:
# Create a type acting as a pseudo-object representing the protocol
# (e.g. p2p)
type `protocolName`* = object
# The protocol run-time index is available as a pseudo-field
# (e.g. `p2p.index`)
template index*(`PROTO`: type `protocolName`): auto = `protocolIdx`
template protocolInfo*(`PROTO`: type `protocolName`): auto =
getProtocol(`protocolIdx`)
if peerState != nil:
result.add quote do:
template State*(`PROTO`: type `protocolName`): type = `peerState`
@ -949,33 +960,29 @@ proc genCode*(p: P2PProtocol): NimNode =
result.add p.genTypeSection()
let
protocolInfoVar = p.protocolInfoVar
protocolInfoVarObj = ident($protocolInfoVar & "Obj")
protocolName = p.nameIdent
protocolInit = p.backend.implementProtocolInit(p)
result.add quote do:
# One global variable per protocol holds the protocol run-time data
var `protocolInfoVarObj` = `protocolInit`
var `protocolInfoVar` = addr `protocolInfoVarObj`
# The protocol run-time data is available as a pseudo-field
# (e.g. `p2p.protocolInfo`)
template protocolInfo*(`PROTO`: type `protocolName`): auto = `protocolInfoVar`
protocolReg = ident($p.nameIdent & "Registration")
regBody = newStmtList()
result.add p.outSendProcs,
p.outRecvProcs,
p.outProcRegistrations
p.outRecvProcs
if p.onPeerConnected != nil: result.add p.onPeerConnected
if p.onPeerDisconnected != nil: result.add p.onPeerDisconnected
result.add newCall(p.backend.setEventHandlers,
protocolInfoVar,
regBody.add newCall(p.backend.setEventHandlers,
protocolVar,
nameOrNil p.onPeerConnected,
nameOrNil p.onPeerDisconnected)
result.add newCall(p.backend.registerProtocol, protocolInfoVar)
regBody.add p.outProcRegistrations
regBody.add newCall(p.backend.registerProtocol, protocolVar)
result.add quote do:
proc `protocolReg`() {.raises: [RlpError, Defect].} =
let `protocolVar` = `protocolInit`
`regBody`
`protocolReg`()
macro emitForSingleBackend(
name: static[string],

View File

@ -93,7 +93,10 @@ type
## Quasy-private types. Use at your own risk.
##
ProtocolInfoObj* = object
ProtocolManager* = ref object
protocols*: seq[ProtocolInfo]
ProtocolInfo* = ref object
name*: string
version*: int
messages*: seq[MessageInfo]
@ -106,9 +109,7 @@ type
handshake*: HandshakeStep
disconnectHandler*: DisconnectionHandler
ProtocolInfo* = ptr ProtocolInfoObj
MessageInfo* = object
MessageInfo* = ref object
id*: int
name*: string
@ -132,7 +133,7 @@ type
# `messages` holds a mapping from valid message IDs to their handler procs.
#
protocolOffsets*: seq[int]
messages*: seq[ptr MessageInfo]
messages*: seq[MessageInfo]
activeProtocols*: seq[ProtocolInfo]
##

View File

@ -192,9 +192,6 @@ proc handshakeImpl[T](peer: Peer,
else:
return responseFut.read
var gDevp2pInfo: ProtocolInfo
template devp2pInfo: auto = {.gcsafe.}: gDevp2pInfo
# Dispatcher
#
@ -220,7 +217,7 @@ proc getDispatcher(node: EthereumNode,
# We should be able to find an existing dispatcher without allocating a new one
new result
newSeq(result.protocolOffsets, allProtocols.len)
newSeq(result.protocolOffsets, protocolCount())
result.protocolOffsets.fill -1
var nextUserMsgId = 0x10
@ -237,9 +234,9 @@ proc getDispatcher(node: EthereumNode,
template copyTo(src, dest; index: int) =
for i in 0 ..< src.len:
dest[index + i] = addr src[i]
dest[index + i] = src[i]
result.messages = newSeq[ptr MessageInfo](nextUserMsgId)
result.messages = newSeq[MessageInfo](nextUserMsgId)
devp2pInfo.messages.copyTo(result.messages, 0)
for localProtocol in node.protocols:
@ -262,30 +259,35 @@ proc getMsgName*(peer: Peer, msgId: int): string =
of 3: "pong"
else: $msgId
proc getMsgMetadata*(peer: Peer, msgId: int): (ProtocolInfo, ptr MessageInfo) =
proc getMsgMetadata*(peer: Peer, msgId: int): (ProtocolInfo, MessageInfo) =
doAssert msgId >= 0
if msgId <= devp2pInfo.messages[^1].id:
return (devp2pInfo, addr devp2pInfo.messages[msgId])
let dpInfo = devp2pInfo()
if msgId <= dpInfo.messages[^1].id:
return (dpInfo, dpInfo.messages[msgId])
if msgId < peer.dispatcher.messages.len:
for i in 0 ..< allProtocols.len:
let numProtocol = protocolCount()
for i in 0 ..< numProtocol:
let protocol = getProtocol(i)
let offset = peer.dispatcher.protocolOffsets[i]
if offset != -1 and
offset + allProtocols[i].messages[^1].id >= msgId:
return (allProtocols[i], peer.dispatcher.messages[msgId])
offset + protocol.messages[^1].id >= msgId:
return (protocol, peer.dispatcher.messages[msgId])
# Protocol info objects
#
proc initProtocol(name: string, version: int,
peerInit: PeerStateInitializer,
networkInit: NetworkStateInitializer): ProtocolInfoObj =
result.name = name
result.version = version
result.messages = @[]
result.peerStateInitializer = peerInit
result.networkStateInitializer = networkInit
networkInit: NetworkStateInitializer): ProtocolInfo =
ProtocolInfo(
name : name,
version : version,
messages: @[],
peerStateInitializer: peerInit,
networkStateInitializer: networkInit
)
proc setEventHandlers(p: ProtocolInfo,
handshake: HandshakeStep,
@ -321,16 +323,6 @@ proc registerMsg(protocol: ProtocolInfo,
requestResolver: requestResolver,
nextMsgResolver: nextMsgResolver)
proc registerProtocol(protocol: ProtocolInfo) =
# TODO: This can be done at compile-time in the future
if protocol.name != "p2p":
let pos = lowerBound(gProtocols, protocol)
gProtocols.insert(protocol, pos)
for i in 0 ..< gProtocols.len:
gProtocols[i].index = i
else:
gDevp2pInfo = protocol
# Message composition and encryption
#
@ -973,7 +965,7 @@ proc p2pProtocolBackendImpl*(protocol: P2PProtocol): Backend =
quote: return `sendCall`
let perPeerMsgIdValue = if isSubprotocol:
newCall(perPeerMsgIdImpl, peerVar, protocol.protocolInfoVar, newLit(msgId))
newCall(perPeerMsgIdImpl, peerVar, protocol.protocolInfo, newLit(msgId))
else:
newLit(msgId)
@ -1009,7 +1001,7 @@ proc p2pProtocolBackendImpl*(protocol: P2PProtocol): Backend =
protocol.outProcRegistrations.add(
newCall(registerMsg,
protocol.protocolInfoVar,
protocolVar,
newLit(msgId),
newLit(msgName),
thunkName,
@ -1063,7 +1055,7 @@ proc removePeer(network: EthereumNode, peer: Peer) =
proc callDisconnectHandlers(peer: Peer, reason: DisconnectionReason):
Future[void] {.async.} =
var futures = newSeqOfCap[Future[void]](allProtocols.len)
var futures = newSeqOfCap[Future[void]](protocolCount())
for protocol in peer.dispatcher.activeProtocols:
if protocol.disconnectHandler != nil:
@ -1144,7 +1136,7 @@ proc postHelloSteps(peer: Peer, h: DevP2P.hello) {.async.} =
# chance to send any initial packages they might require over
# the network and to yield on their `nextMsg` waits.
#
var subProtocolsHandshakes = newSeqOfCap[Future[void]](allProtocols.len)
var subProtocolsHandshakes = newSeqOfCap[Future[void]](protocolCount())
for protocol in peer.dispatcher.activeProtocols:
if protocol.handshake != nil:
subProtocolsHandshakes.add((protocol.handshake)(peer))

View File

@ -6,7 +6,7 @@ import
# real eth protocol implementation is in nimbus-eth1 repo
type
PeerState = ref object
PeerState = ref object of RootRef
initialized*: bool
p2pProtocol eth(version = 63,

View File

@ -16,9 +16,12 @@ import
./p2p_test_helper
type
network = ref object
network = ref object of RootRef
count*: int
PeerState = ref object of RootRef
status*: string
p2pProtocol abc(version = 1,
rlpxName = "abc",
networkState = network):
@ -33,15 +36,18 @@ p2pProtocol abc(version = 1,
p2pProtocol xyz(version = 1,
rlpxName = "xyz",
networkState = network):
networkState = network,
peerState = PeerState):
onPeerConnected do (peer: Peer):
peer.networkState.count += 1
peer.state.status = "connected"
onPeerDisconnected do (peer: Peer, reason: DisconnectionReason) {.gcsafe.}:
peer.networkState.count -= 1
if true:
raise newException(CatchableError, "Fake xyz exception")
peer.state.status = "disconnected"
p2pProtocol hah(version = 1,
rlpxName = "hah",
@ -67,6 +73,7 @@ suite "Testing protocol handlers":
let peer = await node1.rlpxConnect(newNode(node2.toENode()))
check:
peer.isNil == false
peer.state(xyz).status == "connected"
await peer.disconnect(SubprotocolReason, true)
check:
@ -74,6 +81,7 @@ suite "Testing protocol handlers":
# handlers, each handler still ran
node1.protocolState(abc).count == 0
node1.protocolState(xyz).count == 0
peer.state(xyz).status == "connected"
asyncTest "Failing connection handler":
let rng = newRng()