Concurrent dials (#238)

* count published messages

* don't call `switch.dial` in `subscribeToPeer`

* add secureconn constructor

* close in the correct order

* concurent dial lock and track in/out conns better

* make tests pass

* add todo comment

* disconect peers that open too many connections

* wip

* do connection and muxer tracking in one place

* prevent nil pointer in observers

* drop connections when peers is over max

* prevent channel leaks

* don't use closure to handle channel
This commit is contained in:
Dmitriy Ryajov 2020-06-24 09:08:44 -06:00 committed by GitHub
parent 83b6ead857
commit 7a95f1844b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 309 additions and 142 deletions

View File

@ -25,7 +25,6 @@ type
Mplex* = ref object of Muxer Mplex* = ref object of Muxer
remote: Table[uint64, LPChannel] remote: Table[uint64, LPChannel]
local: Table[uint64, LPChannel] local: Table[uint64, LPChannel]
handlerFuts: seq[Future[void]]
currentId*: uint64 currentId*: uint64
maxChannels*: uint64 maxChannels*: uint64
isClosed: bool isClosed: bool
@ -66,6 +65,15 @@ proc newStreamInternal*(m: Mplex,
m.getChannelList(initiator)[id] = result m.getChannelList(initiator)[id] = result
proc handleStream(m: Muxer, chann: LPChannel) {.async.} =
try:
await m.streamHandler(chann)
trace "finished handling stream"
doAssert(chann.closed, "connection not closed by handler!")
except CatchableError as exc:
trace "exception in stream handler", exc = exc.msg
await chann.reset()
method handle*(m: Mplex) {.async, gcsafe.} = method handle*(m: Mplex) {.async, gcsafe.} =
trace "starting mplex main loop", oid = m.oid trace "starting mplex main loop", oid = m.oid
try: try:
@ -96,7 +104,7 @@ method handle*(m: Mplex) {.async, gcsafe.} =
initiator = initiator initiator = initiator
msgType = msgType msgType = msgType
size = data.len size = data.len
oid = m.oid muxer_oid = m.oid
case msgType: case msgType:
of MessageType.New: of MessageType.New:
@ -104,27 +112,16 @@ method handle*(m: Mplex) {.async, gcsafe.} =
channel = await m.newStreamInternal(false, id, name) channel = await m.newStreamInternal(false, id, name)
trace "created channel", name = channel.name, trace "created channel", name = channel.name,
chann_iod = channel.oid oid = channel.oid
if not isNil(m.streamHandler): if not isNil(m.streamHandler):
var fut = newFuture[void]() # launch handler task
proc handler() {.async.} = asyncCheck m.handleStream(channel)
try:
await m.streamHandler(channel)
trace "finished handling stream"
# doAssert(channel.closed, "connection not closed by handler!")
except CatchableError as exc:
trace "exception in stream handler", exc = exc.msg
await channel.reset()
finally:
m.handlerFuts.keepItIf(it != fut)
fut = handler()
of MessageType.MsgIn, MessageType.MsgOut: of MessageType.MsgIn, MessageType.MsgOut:
logScope: logScope:
name = channel.name name = channel.name
chann_iod = channel.oid oid = channel.oid
trace "pushing data to channel" trace "pushing data to channel"
@ -134,7 +131,7 @@ method handle*(m: Mplex) {.async, gcsafe.} =
of MessageType.CloseIn, MessageType.CloseOut: of MessageType.CloseIn, MessageType.CloseOut:
logScope: logScope:
name = channel.name name = channel.name
chann_iod = channel.oid oid = channel.oid
trace "closing channel" trace "closing channel"
@ -144,7 +141,7 @@ method handle*(m: Mplex) {.async, gcsafe.} =
of MessageType.ResetIn, MessageType.ResetOut: of MessageType.ResetIn, MessageType.ResetOut:
logScope: logScope:
name = channel.name name = channel.name
chann_iod = channel.oid oid = channel.oid
trace "resetting channel" trace "resetting channel"
@ -201,12 +198,9 @@ method close*(m: Mplex) {.async, gcsafe.} =
except CatchableError as exc: except CatchableError as exc:
warn "error resetting channel", exc = exc.msg warn "error resetting channel", exc = exc.msg
checkFutures(
await allFinished(m.handlerFuts))
await m.connection.close() await m.connection.close()
finally: finally:
m.remote.clear() m.remote.clear()
m.local.clear() m.local.clear()
m.handlerFuts = @[] # m.handlerFuts = @[]
m.isClosed = true m.isClosed = true

View File

@ -60,12 +60,14 @@ proc recvObservers(p: PubSubPeer, msg: var RPCMsg) =
# trigger hooks # trigger hooks
if not(isNil(p.observers)) and p.observers[].len > 0: if not(isNil(p.observers)) and p.observers[].len > 0:
for obs in p.observers[]: for obs in p.observers[]:
if not(isNil(obs)): # TODO: should never be nil, but...
obs.onRecv(p, msg) obs.onRecv(p, msg)
proc sendObservers(p: PubSubPeer, msg: var RPCMsg) = proc sendObservers(p: PubSubPeer, msg: var RPCMsg) =
# trigger hooks # trigger hooks
if not(isNil(p.observers)) and p.observers[].len > 0: if not(isNil(p.observers)) and p.observers[].len > 0:
for obs in p.observers[]: for obs in p.observers[]:
if not(isNil(obs)): # TODO: should never be nil, but...
obs.onSend(p, msg) obs.onSend(p, msg)
proc handle*(p: PubSubPeer, conn: Connection) {.async.} = proc handle*(p: PubSubPeer, conn: Connection) {.async.} =

View File

@ -467,13 +467,9 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
debug "Noise handshake, peer infos don't match!", initiator, dealt_peer = $conn.peerInfo.id, dealt_key = $failedKey, received_peer = $pid, received_key = $remotePubKey debug "Noise handshake, peer infos don't match!", initiator, dealt_peer = $conn.peerInfo.id, dealt_key = $failedKey, received_peer = $pid, received_key = $remotePubKey
raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerInfo.peerId) raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerInfo.peerId)
var secure = new NoiseConnection var secure = NoiseConnection.init(conn,
secure.initStream() PeerInfo.init(remotePubKey),
conn.observedAddr)
secure.stream = conn
secure.peerInfo = PeerInfo.init(remotePubKey)
secure.observedAddr = conn.observedAddr
if initiator: if initiator:
secure.readCs = handshakeRes.cs2 secure.readCs = handshakeRes.cs2
secure.writeCs = handshakeRes.cs1 secure.writeCs = handshakeRes.cs1

View File

@ -245,9 +245,9 @@ proc newSecioConn(conn: Connection,
## Create new secure stream/lpstream, using specified hash algorithm ``hash``, ## Create new secure stream/lpstream, using specified hash algorithm ``hash``,
## cipher algorithm ``cipher``, stretched keys ``secrets`` and order ## cipher algorithm ``cipher``, stretched keys ``secrets`` and order
## ``order``. ## ``order``.
new result result = SecioConn.init(conn,
result.initStream() PeerInfo.init(remotePubKey),
result.stream = conn conn.observedAddr)
let i0 = if order < 0: 1 else: 0 let i0 = if order < 0: 1 else: 0
let i1 = if order < 0: 0 else: 1 let i1 = if order < 0: 0 else: 1
@ -265,9 +265,6 @@ proc newSecioConn(conn: Connection,
result.readerCoder.init(cipher, secrets.keyOpenArray(i1), result.readerCoder.init(cipher, secrets.keyOpenArray(i1),
secrets.ivOpenArray(i1)) secrets.ivOpenArray(i1))
result.peerInfo = PeerInfo.init(remotePubKey)
result.observedAddr = conn.observedAddr
proc transactMessage(conn: Connection, proc transactMessage(conn: Connection,
msg: seq[byte]): Future[seq[byte]] {.async.} = msg: seq[byte]): Future[seq[byte]] {.async.} =
trace "Sending message", message = msg.shortLog, length = len(msg) trace "Sending message", message = msg.shortLog, length = len(msg)

View File

@ -12,6 +12,7 @@ import chronos, chronicles
import ../protocol, import ../protocol,
../../stream/streamseq, ../../stream/streamseq,
../../stream/connection, ../../stream/connection,
../../multiaddress,
../../peerinfo ../../peerinfo
logScope: logScope:
@ -24,6 +25,16 @@ type
stream*: Connection stream*: Connection
buf: StreamSeq buf: StreamSeq
proc init*[T: SecureConn](C: type T,
conn: Connection,
peerInfo: PeerInfo,
observedAddr: Multiaddress): T =
result = C(stream: conn,
peerInfo: peerInfo,
observedAddr: observedAddr,
closeEvent: conn.closeEvent)
result.initStream()
method initStream*(s: SecureConn) = method initStream*(s: SecureConn) =
if s.objName.len == 0: if s.objName.len == 0:
s.objName = "SecureConn" s.objName = "SecureConn"
@ -31,11 +42,11 @@ method initStream*(s: SecureConn) =
procCall Connection(s).initStream() procCall Connection(s).initStream()
method close*(s: SecureConn) {.async.} = method close*(s: SecureConn) {.async.} =
await procCall Connection(s).close()
if not(isNil(s.stream)): if not(isNil(s.stream)):
await s.stream.close() await s.stream.close()
await procCall Connection(s).close()
method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} = method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} =
doAssert(false, "Not implemented!") doAssert(false, "Not implemented!")
@ -47,11 +58,12 @@ method handshake(s: Secure,
proc handleConn*(s: Secure, conn: Connection, initiator: bool): Future[Connection] {.async, gcsafe.} = proc handleConn*(s: Secure, conn: Connection, initiator: bool): Future[Connection] {.async, gcsafe.} =
var sconn = await s.handshake(conn, initiator) var sconn = await s.handshake(conn, initiator)
result = sconn conn.closeEvent.wait()
result.observedAddr = conn.observedAddr .addCallback do(udata: pointer = nil):
if not(isNil(sconn)):
asyncCheck sconn.close()
if not isNil(sconn.peerInfo) and sconn.peerInfo.publicKey.isSome: return sconn
result.peerInfo = PeerInfo.init(sconn.peerInfo.publicKey.get())
method init*(s: Secure) {.gcsafe.} = method init*(s: Secure) {.gcsafe.} =
proc handle(conn: Connection, proto: string) {.async, gcsafe.} = proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
@ -94,7 +106,7 @@ method readExactly*(s: SecureConn,
let consumed = s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1)) let consumed = s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1))
doAssert consumed == nbytes, "checked above" doAssert consumed == nbytes, "checked above"
except CatchableError as exc: except CatchableError as exc:
trace "exception reading from secure connection", exc = exc.msg trace "exception reading from secure connection", exc = exc.msg, oid = s.oid
await s.close() # make sure to close the wrapped connection await s.close() # make sure to close the wrapped connection
raise exc raise exc
@ -115,6 +127,6 @@ method readOnce*(s: SecureConn,
var p = cast[ptr UncheckedArray[byte]](pbytes) var p = cast[ptr UncheckedArray[byte]](pbytes)
return s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1)) return s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1))
except CatchableError as exc: except CatchableError as exc:
trace "exception reading from secure connection", exc = exc.msg trace "exception reading from secure connection", exc = exc.msg, oid = s.oid
await s.close() # make sure to close the wrapped connection await s.close() # make sure to close the wrapped connection
raise exc raise exc

View File

@ -82,12 +82,11 @@ method atEof*(s: ChronosStream): bool {.inline.} =
method close*(s: ChronosStream) {.async.} = method close*(s: ChronosStream) {.async.} =
try: try:
if not s.isClosed: if not s.isClosed:
s.isClosed = true await procCall Connection(s).close()
trace "shutting down chronos stream", address = $s.client.remoteAddress() trace "shutting down chronos stream", address = $s.client.remoteAddress(), oid = s.oid
if not s.client.closed(): if not s.client.closed():
await s.client.closeWait() await s.client.closeWait()
await procCall Connection(s).close()
except CatchableError as exc: except CatchableError as exc:
trace "error closing chronosstream", exc = exc.msg trace "error closing chronosstream", exc = exc.msg

View File

@ -21,7 +21,6 @@ type
Connection* = ref object of LPStream Connection* = ref object of LPStream
peerInfo*: PeerInfo peerInfo*: PeerInfo
observedAddr*: Multiaddress observedAddr*: Multiaddress
closeEvent*: AsyncEvent
ConnectionTracker* = ref object of TrackerBase ConnectionTracker* = ref object of TrackerBase
opened*: uint64 opened*: uint64
@ -65,8 +64,6 @@ method initStream*(s: Connection) =
method close*(s: Connection) {.async.} = method close*(s: Connection) {.async.} =
await procCall LPStream(s).close() await procCall LPStream(s).close()
s.closeEvent.fire()
inc getConnectionTracker().closed inc getConnectionTracker().closed
proc `$`*(conn: Connection): string = proc `$`*(conn: Connection): string =

View File

@ -18,6 +18,7 @@ declareGauge(libp2p_open_streams, "open stream instances", labels = ["type"])
type type
LPStream* = ref object of RootObj LPStream* = ref object of RootObj
closeEvent*: AsyncEvent
isClosed*: bool isClosed*: bool
isEof*: bool isEof*: bool
objName*: string objName*: string
@ -73,7 +74,19 @@ method initStream*(s: LPStream) {.base.} =
s.oid = genOid() s.oid = genOid()
libp2p_open_streams.inc(labelValues = [s.objName]) libp2p_open_streams.inc(labelValues = [s.objName])
trace "stream created", oid = s.oid trace "stream created", oid = s.oid, name = s.objName
# TODO: debuging aid to troubleshoot streams open/close
# try:
# echo "ChronosStream ", libp2p_open_streams.value(labelValues = ["ChronosStream"])
# echo "SecureConn ", libp2p_open_streams.value(labelValues = ["SecureConn"])
# # doAssert(libp2p_open_streams.value(labelValues = ["ChronosStream"]) >=
# # libp2p_open_streams.value(labelValues = ["SecureConn"]))
# except CatchableError:
# discard
proc join*(s: LPStream): Future[void] =
s.closeEvent.wait()
method closed*(s: LPStream): bool {.base, inline.} = method closed*(s: LPStream): bool {.base, inline.} =
s.isClosed s.isClosed
@ -169,6 +182,16 @@ proc write*(s: LPStream, msg: string): Future[void] =
method close*(s: LPStream) {.base, async.} = method close*(s: LPStream) {.base, async.} =
if not s.isClosed: if not s.isClosed:
libp2p_open_streams.dec(labelValues = [s.objName])
s.isClosed = true s.isClosed = true
trace "stream destroyed", oid = s.oid s.closeEvent.fire()
libp2p_open_streams.dec(labelValues = [s.objName])
trace "stream destroyed", oid = s.oid, name = s.objName
# TODO: debuging aid to troubleshoot streams open/close
# try:
# echo "ChronosStream ", libp2p_open_streams.value(labelValues = ["ChronosStream"])
# echo "SecureConn ", libp2p_open_streams.value(labelValues = ["SecureConn"])
# # doAssert(libp2p_open_streams.value(labelValues = ["ChronosStream"]) >=
# # libp2p_open_streams.value(labelValues = ["SecureConn"]))
# except CatchableError:
# discard

View File

@ -7,8 +7,17 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import tables, sequtils, options, strformat, sets import tables,
import chronos, chronicles, metrics sequtils,
options,
strformat,
sets,
algorithm
import chronos,
chronicles,
metrics
import stream/connection, import stream/connection,
stream/chronosstream, stream/chronosstream,
transports/transport, transports/transport,
@ -38,13 +47,28 @@ declareCounter(libp2p_dialed_peers, "dialed peers")
declareCounter(libp2p_failed_dials, "failed dials") declareCounter(libp2p_failed_dials, "failed dials")
declareCounter(libp2p_failed_upgrade, "peers failed upgrade") declareCounter(libp2p_failed_upgrade, "peers failed upgrade")
const MaxConnectionsPerPeer = 5
type type
NoPubSubException = object of CatchableError NoPubSubException* = object of CatchableError
TooManyConnections* = object of CatchableError
Direction {.pure.} = enum
In, Out
ConnectionHolder = object
dir: Direction
conn: Connection
MuxerHolder = object
dir: Direction
muxer: Muxer
handle: Future[void]
Switch* = ref object of RootObj Switch* = ref object of RootObj
peerInfo*: PeerInfo peerInfo*: PeerInfo
connections*: Table[string, Connection] connections*: Table[string, seq[ConnectionHolder]]
muxed*: Table[string, Muxer] muxed*: Table[string, seq[MuxerHolder]]
transports*: seq[Transport] transports*: seq[Transport]
protocols*: seq[LPProtocol] protocols*: seq[LPProtocol]
muxers*: Table[string, MuxerProvider] muxers*: Table[string, MuxerProvider]
@ -54,10 +78,84 @@ type
secureManagers*: seq[Secure] secureManagers*: seq[Secure]
pubSub*: Option[PubSub] pubSub*: Option[PubSub]
dialedPubSubPeers: HashSet[string] dialedPubSubPeers: HashSet[string]
dialLock: Table[string, AsyncLock]
proc newNoPubSubException(): ref CatchableError {.inline.} = proc newNoPubSubException(): ref NoPubSubException {.inline.} =
result = newException(NoPubSubException, "no pubsub provided!") result = newException(NoPubSubException, "no pubsub provided!")
proc newTooManyConnections(): ref TooManyConnections {.inline.} =
result = newException(TooManyConnections, "too many connections for peer")
proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.}
proc selectConn(s: Switch, peerInfo: PeerInfo): Connection =
## select the "best" connection according to some criteria
##
## Ideally when the connection's stats are available
## we'd select the fastest, but for now we simply pick an outgoing
## connection first if none is available, we pick the first outgoing
##
if isNil(peerInfo):
return
let conns = s.connections
.getOrDefault(peerInfo.id)
# it should be OK to sort on each
# access as there should only be
# up to MaxConnectionsPerPeer entries
.sorted(
proc(a, b: ConnectionHolder): int =
if a.dir < b.dir: -1
elif a.dir == b.dir: 0
else: 1
, SortOrder.Descending)
if conns.len > 0:
return conns[0].conn
proc selectMuxer(s: Switch, conn: Connection): Muxer =
## select the muxer for the supplied connection
##
if isNil(conn):
return
if not(isNil(conn.peerInfo)) and conn.peerInfo.id in s.muxed:
if s.muxed[conn.peerInfo.id].len > 0:
let muxers = s.muxed[conn.peerInfo.id]
.filterIt( it.muxer.connection == conn )
if muxers.len > 0:
return muxers[0].muxer
proc storeConn(s: Switch,
muxer: Muxer,
dir: Direction,
handle: Future[void] = nil) {.async.} =
## store the connection and muxer
##
if not(isNil(muxer)):
let conn = muxer.connection
if not(isNil(conn)):
let id = conn.peerInfo.id
if s.connections.getOrDefault(id).len > MaxConnectionsPerPeer:
warn "disconnecting peer, too many connections", peer = $conn.peerInfo,
conns = s.connections
.getOrDefault(id).len
await muxer.close()
await s.disconnect(conn.peerInfo)
raise newTooManyConnections()
s.connections.mgetOrPut(
id,
newSeq[ConnectionHolder]())
.add(ConnectionHolder(conn: conn, dir: dir))
s.muxed.mgetOrPut(
muxer.connection.peerInfo.id,
newSeq[MuxerHolder]())
.add(MuxerHolder(muxer: muxer, handle: handle, dir: dir))
proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
if s.secureManagers.len <= 0: if s.secureManagers.len <= 0:
raise newException(CatchableError, "No secure managers registered!") raise newException(CatchableError, "No secure managers registered!")
@ -137,11 +235,6 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} =
# not end until muxer ends # not end until muxer ends
let handlerFut = muxer.handle() let handlerFut = muxer.handle()
# add muxer handler cleanup proc
handlerFut.addCallback do (udata: pointer = nil):
trace "muxer handler completed for peer",
peer = conn.peerInfo.id
try: try:
# do identify first, so that we have a # do identify first, so that we have a
# PeerInfo in case we didn't before # PeerInfo in case we didn't before
@ -149,10 +242,13 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} =
finally: finally:
await stream.close() # close identify stream await stream.close() # close identify stream
if isNil(conn.peerInfo):
await muxer.close()
return
# store it in muxed connections if we have a peer for it # store it in muxed connections if we have a peer for it
if not isNil(conn.peerInfo):
trace "adding muxer for peer", peer = conn.peerInfo.id trace "adding muxer for peer", peer = conn.peerInfo.id
s.muxed[conn.peerInfo.id] = muxer await s.storeConn(muxer, Direction.Out, handlerFut)
proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} = proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} =
try: try:
@ -160,55 +256,82 @@ proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} =
let id = conn.peerInfo.id let id = conn.peerInfo.id
trace "cleaning up connection for peer", peerId = id trace "cleaning up connection for peer", peerId = id
if id in s.muxed: if id in s.muxed:
await s.muxed[id].close() let muxerHolder = s.muxed[id]
.filterIt(
it.muxer.connection == conn
)
if muxerHolder.len > 0:
await muxerHolder[0].muxer.close()
if not(isNil(muxerHolder[0].handle)):
await muxerHolder[0].handle
s.muxed[id].keepItIf(
it.muxer.connection != conn
)
if s.muxed[id].len == 0:
s.muxed.del(id) s.muxed.del(id)
if id in s.connections: if id in s.connections:
s.connections[id].keepItIf(
it.conn != conn
)
if s.connections[id].len == 0:
s.connections.del(id) s.connections.del(id)
await conn.close() await conn.close()
s.dialedPubSubPeers.excl(id) s.dialedPubSubPeers.excl(id)
libp2p_peers.dec()
# TODO: Investigate cleanupConn() always called twice for one peer. # TODO: Investigate cleanupConn() always called twice for one peer.
if not(conn.peerInfo.isClosed()): if not(conn.peerInfo.isClosed()):
conn.peerInfo.close() conn.peerInfo.close()
except CatchableError as exc: except CatchableError as exc:
trace "exception cleaning up connection", exc = exc.msg trace "exception cleaning up connection", exc = exc.msg
finally:
libp2p_peers.set(s.connections.len.int64)
proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} = proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} =
let conn = s.connections.getOrDefault(peer.id) let connections = s.connections.getOrDefault(peer.id)
if not isNil(conn): for connHolder in connections:
trace "disconnecting peer", peer = $peer if not isNil(connHolder.conn):
await s.cleanupConn(conn) await s.cleanupConn(connHolder.conn)
proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Connection] {.async, gcsafe.} = proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Connection] {.async, gcsafe.} =
# if there is a muxer for the connection # if there is a muxer for the connection
# use it instead to create a muxed stream # use it instead to create a muxed stream
if peerInfo.id in s.muxed:
trace "connection is muxed, setting up a stream" let muxer = s.selectMuxer(s.selectConn(peerInfo)) # always get the first muxer here
let muxer = s.muxed[peerInfo.id] if not(isNil(muxer)):
let conn = await muxer.newStream() return await muxer.newStream()
result = conn
proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
trace "handling connection", conn = $conn trace "handling connection", conn = $conn, oid = conn.oid
result = conn
# don't mux/secure twise let sconn = await s.secure(conn) # secure the connection
if conn.peerInfo.id in s.muxed: if isNil(sconn):
trace "unable to secure connection, stopping upgrade", conn = $conn,
oid = conn.oid
await conn.close()
return return
result = await s.secure(result) # secure the connection await s.mux(sconn) # mux it if possible
if isNil(result): if isNil(conn.peerInfo):
trace "unable to mux connection, stopping upgrade", conn = $conn,
oid = conn.oid
await sconn.close()
return return
await s.mux(result) # mux it if possible libp2p_peers.set(s.connections.len.int64)
s.connections[conn.peerInfo.id] = result trace "succesfully upgraded outgoing connection", conn = $conn,
oid = conn.oid,
uoid = sconn.oid
result = sconn
proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
trace "upgrading incoming connection", conn = $conn trace "upgrading incoming connection", conn = $conn, oid = conn.oid
let ms = newMultistream() let ms = newMultistream()
# secure incoming connections # secure incoming connections
@ -216,7 +339,7 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
proto: string) proto: string)
{.async, gcsafe, closure.} = {.async, gcsafe, closure.} =
try: try:
trace "Securing connection" trace "Securing connection", oid = conn.oid
let secure = s.secureManagers.filterIt(it.codec == proto)[0] let secure = s.secureManagers.filterIt(it.codec == proto)[0]
let sconn = await secure.secure(conn, false) let sconn = await secure.secure(conn, false)
if sconn.isNil: if sconn.isNil:
@ -257,10 +380,20 @@ proc subscribeToPeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.}
proc internalConnect(s: Switch, proc internalConnect(s: Switch,
peer: PeerInfo): Future[Connection] {.async.} = peer: PeerInfo): Future[Connection] {.async.} =
if s.peerInfo.peerId == peer.peerId:
raise newException(CatchableError, "can't dial self!")
let id = peer.id let id = peer.id
trace "Dialing peer", peer = id let lock = s.dialLock.mgetOrPut(id, newAsyncLock())
var conn = s.connections.getOrDefault(id) var conn: Connection
try:
await lock.acquire()
trace "about to dial peer", peer = id
conn = s.selectConn(peer)
if conn.isNil or conn.closed: if conn.isNil or conn.closed:
trace "Dialing peer", peer = id
for t in s.transports: # for each transport for t in s.transports: # for each transport
for a in peer.addrs: # for each address for a in peer.addrs: # for each address
if t.handles(a): # check if it can dial it if t.handles(a): # check if it can dial it
@ -275,7 +408,6 @@ proc internalConnect(s: Switch,
# make sure to assign the peer to the connection # make sure to assign the peer to the connection
conn.peerInfo = peer conn.peerInfo = peer
conn = await s.upgradeOutgoing(conn) conn = await s.upgradeOutgoing(conn)
if isNil(conn): if isNil(conn):
libp2p_failed_upgrade.inc() libp2p_failed_upgrade.inc()
@ -284,15 +416,23 @@ proc internalConnect(s: Switch,
conn.closeEvent.wait() conn.closeEvent.wait()
.addCallback do(udata: pointer): .addCallback do(udata: pointer):
asyncCheck s.cleanupConn(conn) asyncCheck s.cleanupConn(conn)
libp2p_peers.inc()
break break
else: else:
trace "Reusing existing connection" trace "Reusing existing connection", oid = conn.oid
except CatchableError as exc:
trace "exception connecting to peer", exc = exc.msg
if not(isNil(conn)):
await conn.close()
raise exc # re-raise
finally:
if lock.locked():
lock.release()
if not isNil(conn): if not isNil(conn):
doAssert(conn.peerInfo.id in s.connections, "connection not tracked!")
trace "dial succesfull", oid = conn.oid
await s.subscribeToPeer(peer) await s.subscribeToPeer(peer)
result = conn result = conn
proc connect*(s: Switch, peer: PeerInfo) {.async.} = proc connect*(s: Switch, peer: PeerInfo) {.async.} =
@ -314,9 +454,9 @@ proc dial*(s: Switch,
result = conn result = conn
let stream = await s.getMuxedStream(peer) let stream = await s.getMuxedStream(peer)
if not isNil(stream): if not isNil(stream):
trace "Connection is muxed, return muxed stream" trace "Connection is muxed, return muxed stream", oid = conn.oid
result = stream result = stream
trace "Attempting to select remote", proto = proto trace "Attempting to select remote", proto = proto, oid = conn.oid
if not await s.ms.select(result, proto): if not await s.ms.select(result, proto):
raise newException(CatchableError, "Unable to select sub-protocol " & proto) raise newException(CatchableError, "Unable to select sub-protocol " & proto)
@ -338,7 +478,6 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} = proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} =
try: try:
try: try:
libp2p_peers.inc()
await s.upgradeIncoming(conn) # perform upgrade on incoming connection await s.upgradeIncoming(conn) # perform upgrade on incoming connection
finally: finally:
await s.cleanupConn(conn) await s.cleanupConn(conn)
@ -358,6 +497,7 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
if s.pubSub.isSome: if s.pubSub.isSome:
await s.pubSub.get().start() await s.pubSub.get().start()
info "started libp2p node", peer = $s.peerInfo, addrs = s.peerInfo.addrs
result = startFuts # listen for incoming connections result = startFuts # listen for incoming connections
proc stop*(s: Switch) {.async.} = proc stop*(s: Switch) {.async.} =
@ -370,9 +510,10 @@ proc stop*(s: Switch) {.async.} =
if s.pubSub.isSome: if s.pubSub.isSome:
await s.pubSub.get().stop() await s.pubSub.get().stop()
for conn in toSeq(s.connections.values): for conns in toSeq(s.connections.values):
for conn in conns:
try: try:
await s.cleanupConn(conn) await s.cleanupConn(conn.conn)
except CatchableError as exc: except CatchableError as exc:
warn "error cleaning up connections" warn "error cleaning up connections"
@ -463,8 +604,8 @@ proc newSwitch*(peerInfo: PeerInfo,
result.peerInfo = peerInfo result.peerInfo = peerInfo
result.ms = newMultistream() result.ms = newMultistream()
result.transports = transports result.transports = transports
result.connections = initTable[string, Connection]() result.connections = initTable[string, seq[ConnectionHolder]]()
result.muxed = initTable[string, Muxer]() result.muxed = initTable[string, seq[MuxerHolder]]()
result.identity = identity result.identity = identity
result.muxers = muxers result.muxers = muxers
result.secureManagers = @secureManagers result.secureManagers = @secureManagers
@ -494,11 +635,9 @@ proc newSwitch*(peerInfo: PeerInfo,
# identify it # identify it
muxer.connection.peerInfo = await s.identify(stream) muxer.connection.peerInfo = await s.identify(stream)
# store muxer for connection # store muxer and muxed connection
s.muxed[muxer.connection.peerInfo.id] = muxer await s.storeConn(muxer, Direction.In)
libp2p_peers.set(s.connections.len.int64)
# store muxed connection
s.connections[muxer.connection.peerInfo.id] = muxer.connection
muxer.connection.closeEvent.wait() muxer.connection.closeEvent.wait()
.addCallback do(udata: pointer): .addCallback do(udata: pointer):
@ -506,6 +645,7 @@ proc newSwitch*(peerInfo: PeerInfo,
# try establishing a pubsub connection # try establishing a pubsub connection
await s.subscribeToPeer(muxer.connection.peerInfo) await s.subscribeToPeer(muxer.connection.peerInfo)
except CatchableError as exc: except CatchableError as exc:
libp2p_failed_upgrade.inc() libp2p_failed_upgrade.inc()
trace "exception in muxer handler", exc = exc.msg trace "exception in muxer handler", exc = exc.msg

View File

@ -46,6 +46,7 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
suite "GossipSub": suite "GossipSub":
teardown: teardown:
for tracker in testTrackers(): for tracker in testTrackers():
# echo tracker.dump()
check tracker.isLeaked() == false check tracker.isLeaked() == false
test "GossipSub validation should succeed": test "GossipSub validation should succeed":

View File

@ -189,6 +189,7 @@ suite "Interop":
check string.fromBytes(await stream.transp.readLp()) == "test 3" check string.fromBytes(await stream.transp.readLp()) == "test 3"
asyncDiscard stream.transp.writeLp("test 4") asyncDiscard stream.transp.writeLp("test 4")
testFuture.complete() testFuture.complete()
await stream.close()
await daemonNode.addHandler(protos, daemonHandler) await daemonNode.addHandler(protos, daemonHandler)
let conn = await nativeNode.dial(NativePeerInfo.init(daemonPeer.peer, let conn = await nativeNode.dial(NativePeerInfo.init(daemonPeer.peer,
@ -240,6 +241,7 @@ suite "Interop":
var line = await stream.transp.readLine() var line = await stream.transp.readLine()
check line == expect check line == expect
testFuture.complete(line) testFuture.complete(line)
await stream.close()
await daemonNode.addHandler(protos, daemonHandler) await daemonNode.addHandler(protos, daemonHandler)
let conn = await nativeNode.dial(NativePeerInfo.init(daemonPeer.peer, let conn = await nativeNode.dial(NativePeerInfo.init(daemonPeer.peer,
@ -285,9 +287,12 @@ suite "Interop":
discard await stream.transp.writeLp(test) discard await stream.transp.writeLp(test)
result = test == (await wait(testFuture, 10.secs)) result = test == (await wait(testFuture, 10.secs))
await stream.close()
await nativeNode.stop() await nativeNode.stop()
await allFutures(awaiters) await allFutures(awaiters)
await daemonNode.close() await daemonNode.close()
await sleepAsync(1.seconds)
check: check:
waitFor(runTests()) == true waitFor(runTests()) == true
@ -331,6 +336,7 @@ suite "Interop":
await wait(testFuture, 10.secs) await wait(testFuture, 10.secs)
result = true result = true
await stream.close()
await nativeNode.stop() await nativeNode.stop()
await allFutures(awaiters) await allFutures(awaiters)
await daemonNode.close() await daemonNode.close()

View File

@ -192,8 +192,8 @@ suite "Switch":
await switch2.connect(switch1.peerInfo) await switch2.connect(switch1.peerInfo)
check switch1.connections.len > 0 check switch1.connections[switch2.peerInfo.id].len > 0
check switch2.connections.len > 0 check switch2.connections[switch1.peerInfo.id].len > 0
await sleepAsync(100.millis) await sleepAsync(100.millis)
await switch2.disconnect(switch1.peerInfo) await switch2.disconnect(switch1.peerInfo)
@ -207,8 +207,8 @@ suite "Switch":
# echo connTracker.dump() # echo connTracker.dump()
# check connTracker.isLeaked() == false # check connTracker.isLeaked() == false
check switch1.connections.len == 0 check switch2.peerInfo.id notin switch1.connections
check switch2.connections.len == 0 check switch1.peerInfo.id notin switch2.connections
await allFuturesThrowing( await allFuturesThrowing(
switch1.stop(), switch1.stop(),