handle race condition for incoming and outgoing connections
This commit is contained in:
parent
1a987a9c5b
commit
bf8d7a85e1
|
@ -175,19 +175,43 @@ proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, g
|
|||
return
|
||||
|
||||
await s.mux(result) # mux it if possible
|
||||
s.connections[conn.peerInfo.id] = result
|
||||
if result.peerInfo.id notin s.connections:
|
||||
s.connections[result.peerInfo.id] = result
|
||||
|
||||
proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
|
||||
## Upgrade incoming connections, this roughly looks like:
|
||||
## - First, register all the secure handlers and await for
|
||||
## a secure request
|
||||
##
|
||||
## - Next, when a secure request arrives, handle it in the
|
||||
## catch all ``securedHandler`` handler
|
||||
##
|
||||
## - Lastly, register muxers and handler subsequent muxer
|
||||
## requests
|
||||
##
|
||||
trace "upgrading incoming connection"
|
||||
let ms = newMultistream()
|
||||
|
||||
# secure incoming connections
|
||||
proc securedHandler (conn: Connection,
|
||||
proto: string)
|
||||
proc securedHandler (conn: Connection, proto: string)
|
||||
{.async, gcsafe, closure.} =
|
||||
## generic handler for secure managers
|
||||
trace "Securing connection"
|
||||
|
||||
# get the secure handler for the proto
|
||||
let secure = s.secureManagers[proto]
|
||||
let sconn = await secure.secure(conn)
|
||||
|
||||
# if the connection has been already
|
||||
# established while negotiating this
|
||||
# one we drop it
|
||||
if sconn.peerInfo.id in s.connections:
|
||||
await sconn.close()
|
||||
return
|
||||
|
||||
s.connections[sconn.peerInfo.id] = sconn
|
||||
|
||||
# if securing succedded, handle muxer requests
|
||||
if not isNil(sconn):
|
||||
# add the muxer
|
||||
for muxer in s.muxers.values:
|
||||
|
@ -197,7 +221,8 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
|
|||
await ms.handle(sconn)
|
||||
|
||||
if (await ms.select(conn)): # just handshake
|
||||
# add the secure handlers
|
||||
# register all the secure managers to be
|
||||
# handled by the catch all ``securedHandler``
|
||||
for k in s.secureManagers.keys:
|
||||
ms.addHandler(k, securedHandler)
|
||||
|
||||
|
@ -217,8 +242,17 @@ proc dial*(s: Switch,
|
|||
if t.handles(a): # check if it can dial it
|
||||
trace "Dialing address", address = $a
|
||||
conn = await t.dial(a)
|
||||
|
||||
# avoid raicing with incoming connections
|
||||
if peer.id in s.connections:
|
||||
if not isNil(conn) and not conn.closed():
|
||||
await conn.close()
|
||||
conn = s.connections[peer.id]
|
||||
|
||||
# make sure to assign the peer to the connection
|
||||
if isNil(conn.peerInfo):
|
||||
conn.peerInfo = peer
|
||||
|
||||
conn = await s.upgradeOutgoing(conn)
|
||||
if isNil(conn):
|
||||
continue
|
||||
|
@ -234,7 +268,8 @@ proc dial*(s: Switch,
|
|||
raise newException(CatchableError, "Unable to establish outgoing link")
|
||||
|
||||
if proto.len > 0 and not conn.closed:
|
||||
let stream = await s.getMuxedStream(peer)
|
||||
result = conn
|
||||
var stream = await s.getMuxedStream(peer)
|
||||
if not isNil(stream):
|
||||
trace "Connection is muxed, return muxed stream"
|
||||
result = stream
|
||||
|
|
|
@ -106,3 +106,49 @@ suite "Switch":
|
|||
|
||||
check:
|
||||
waitFor(testSwitch()) == true
|
||||
|
||||
test "e2e use switch nested dial":
|
||||
proc testSwitch(): Future[bool] {.async, gcsafe.} =
|
||||
let ma1: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
|
||||
let ma2: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
|
||||
|
||||
var peerInfo1, peerInfo2: PeerInfo
|
||||
var switch1, switch2: Switch
|
||||
(switch1, peerInfo1) = createSwitch(ma1)
|
||||
var awaiters: seq[Future[void]]
|
||||
awaiters.add(await switch1.start())
|
||||
|
||||
(switch2, peerInfo2) = createSwitch(ma2)
|
||||
awaiters.add(await switch2.start())
|
||||
|
||||
var proto1 = new LPProtocol
|
||||
proto1.codec = "/proto/1"
|
||||
var awaiter = newFuture[void]()
|
||||
proc handler1(conn: Connection, proto: string) {.async, gcsafe.} =
|
||||
var nested = await switch1.dial(switch2.peerInfo, "/proto/2")
|
||||
await nested.writeLp("proto 1")
|
||||
check cast[string](await nested.readLp()) == "proto 2"
|
||||
await nested.close()
|
||||
awaiter.complete()
|
||||
|
||||
proto1.handler = handler1
|
||||
switch1.mount(proto1)
|
||||
|
||||
var proto2 = new LPProtocol
|
||||
proto2.codec = "/proto/2"
|
||||
proc handler2(conn: Connection, proto: string) {.async, gcsafe.} =
|
||||
check cast[string](await conn.readLp()) == "proto 1"
|
||||
await conn.writeLp("proto 2")
|
||||
await conn.close()
|
||||
|
||||
proto2.handler = handler2
|
||||
switch2.mount(proto2)
|
||||
|
||||
discard await switch1.dial(switch1.peerInfo, "/proto/1")
|
||||
await awaiter
|
||||
discard allFutures(switch1.stop(), switch2.stop())
|
||||
await allFutures(awaiters)
|
||||
result = true
|
||||
|
||||
check:
|
||||
waitFor(testSwitch()) == true
|
||||
|
|
Loading…
Reference in New Issue