handle race condition for incoming and outgoing connections

This commit is contained in:
Dmitriy Ryajov 2020-02-19 10:33:12 -06:00
parent 1a987a9c5b
commit bf8d7a85e1
2 changed files with 87 additions and 6 deletions

View File

@ -175,19 +175,43 @@ proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, g
return
await s.mux(result) # mux it if possible
s.connections[conn.peerInfo.id] = result
if result.peerInfo.id notin s.connections:
s.connections[result.peerInfo.id] = result
proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
## Upgrade incoming connections, this roughly looks like:
## - First, register all the secure handlers and await for
## a secure request
##
## - Next, when a secure request arrives, handle it in the
## catch all ``securedHandler`` handler
##
## - Lastly, register muxers and handler subsequent muxer
## requests
##
trace "upgrading incoming connection"
let ms = newMultistream()
# secure incoming connections
proc securedHandler (conn: Connection,
proto: string)
proc securedHandler (conn: Connection, proto: string)
{.async, gcsafe, closure.} =
## generic handler for secure managers
trace "Securing connection"
# get the secure handler for the proto
let secure = s.secureManagers[proto]
let sconn = await secure.secure(conn)
# if the connection has been already
# established while negotiating this
# one we drop it
if sconn.peerInfo.id in s.connections:
await sconn.close()
return
s.connections[sconn.peerInfo.id] = sconn
# if securing succedded, handle muxer requests
if not isNil(sconn):
# add the muxer
for muxer in s.muxers.values:
@ -197,7 +221,8 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
await ms.handle(sconn)
if (await ms.select(conn)): # just handshake
# add the secure handlers
# register all the secure managers to be
# handled by the catch all ``securedHandler``
for k in s.secureManagers.keys:
ms.addHandler(k, securedHandler)
@ -217,8 +242,17 @@ proc dial*(s: Switch,
if t.handles(a): # check if it can dial it
trace "Dialing address", address = $a
conn = await t.dial(a)
# avoid raicing with incoming connections
if peer.id in s.connections:
if not isNil(conn) and not conn.closed():
await conn.close()
conn = s.connections[peer.id]
# make sure to assign the peer to the connection
conn.peerInfo = peer
if isNil(conn.peerInfo):
conn.peerInfo = peer
conn = await s.upgradeOutgoing(conn)
if isNil(conn):
continue
@ -234,7 +268,8 @@ proc dial*(s: Switch,
raise newException(CatchableError, "Unable to establish outgoing link")
if proto.len > 0 and not conn.closed:
let stream = await s.getMuxedStream(peer)
result = conn
var stream = await s.getMuxedStream(peer)
if not isNil(stream):
trace "Connection is muxed, return muxed stream"
result = stream

View File

@ -106,3 +106,49 @@ suite "Switch":
check:
waitFor(testSwitch()) == true
test "e2e use switch nested dial":
proc testSwitch(): Future[bool] {.async, gcsafe.} =
let ma1: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
let ma2: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
var peerInfo1, peerInfo2: PeerInfo
var switch1, switch2: Switch
(switch1, peerInfo1) = createSwitch(ma1)
var awaiters: seq[Future[void]]
awaiters.add(await switch1.start())
(switch2, peerInfo2) = createSwitch(ma2)
awaiters.add(await switch2.start())
var proto1 = new LPProtocol
proto1.codec = "/proto/1"
var awaiter = newFuture[void]()
proc handler1(conn: Connection, proto: string) {.async, gcsafe.} =
var nested = await switch1.dial(switch2.peerInfo, "/proto/2")
await nested.writeLp("proto 1")
check cast[string](await nested.readLp()) == "proto 2"
await nested.close()
awaiter.complete()
proto1.handler = handler1
switch1.mount(proto1)
var proto2 = new LPProtocol
proto2.codec = "/proto/2"
proc handler2(conn: Connection, proto: string) {.async, gcsafe.} =
check cast[string](await conn.readLp()) == "proto 1"
await conn.writeLp("proto 2")
await conn.close()
proto2.handler = handler2
switch2.mount(proto2)
discard await switch1.dial(switch1.peerInfo, "/proto/1")
await awaiter
discard allFutures(switch1.stop(), switch2.stop())
await allFutures(awaiters)
result = true
check:
waitFor(testSwitch()) == true