handle race condition for incoming and outgoing connections
This commit is contained in:
parent
1a987a9c5b
commit
bf8d7a85e1
|
@ -175,19 +175,43 @@ proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, g
|
||||||
return
|
return
|
||||||
|
|
||||||
await s.mux(result) # mux it if possible
|
await s.mux(result) # mux it if possible
|
||||||
s.connections[conn.peerInfo.id] = result
|
if result.peerInfo.id notin s.connections:
|
||||||
|
s.connections[result.peerInfo.id] = result
|
||||||
|
|
||||||
proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
|
proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
|
||||||
|
## Upgrade incoming connections, this roughly looks like:
|
||||||
|
## - First, register all the secure handlers and await for
|
||||||
|
## a secure request
|
||||||
|
##
|
||||||
|
## - Next, when a secure request arrives, handle it in the
|
||||||
|
## catch all ``securedHandler`` handler
|
||||||
|
##
|
||||||
|
## - Lastly, register muxers and handler subsequent muxer
|
||||||
|
## requests
|
||||||
|
##
|
||||||
trace "upgrading incoming connection"
|
trace "upgrading incoming connection"
|
||||||
let ms = newMultistream()
|
let ms = newMultistream()
|
||||||
|
|
||||||
# secure incoming connections
|
# secure incoming connections
|
||||||
proc securedHandler (conn: Connection,
|
proc securedHandler (conn: Connection, proto: string)
|
||||||
proto: string)
|
|
||||||
{.async, gcsafe, closure.} =
|
{.async, gcsafe, closure.} =
|
||||||
|
## generic handler for secure managers
|
||||||
trace "Securing connection"
|
trace "Securing connection"
|
||||||
|
|
||||||
|
# get the secure handler for the proto
|
||||||
let secure = s.secureManagers[proto]
|
let secure = s.secureManagers[proto]
|
||||||
let sconn = await secure.secure(conn)
|
let sconn = await secure.secure(conn)
|
||||||
|
|
||||||
|
# if the connection has been already
|
||||||
|
# established while negotiating this
|
||||||
|
# one we drop it
|
||||||
|
if sconn.peerInfo.id in s.connections:
|
||||||
|
await sconn.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
s.connections[sconn.peerInfo.id] = sconn
|
||||||
|
|
||||||
|
# if securing succedded, handle muxer requests
|
||||||
if not isNil(sconn):
|
if not isNil(sconn):
|
||||||
# add the muxer
|
# add the muxer
|
||||||
for muxer in s.muxers.values:
|
for muxer in s.muxers.values:
|
||||||
|
@ -197,7 +221,8 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
|
||||||
await ms.handle(sconn)
|
await ms.handle(sconn)
|
||||||
|
|
||||||
if (await ms.select(conn)): # just handshake
|
if (await ms.select(conn)): # just handshake
|
||||||
# add the secure handlers
|
# register all the secure managers to be
|
||||||
|
# handled by the catch all ``securedHandler``
|
||||||
for k in s.secureManagers.keys:
|
for k in s.secureManagers.keys:
|
||||||
ms.addHandler(k, securedHandler)
|
ms.addHandler(k, securedHandler)
|
||||||
|
|
||||||
|
@ -217,8 +242,17 @@ proc dial*(s: Switch,
|
||||||
if t.handles(a): # check if it can dial it
|
if t.handles(a): # check if it can dial it
|
||||||
trace "Dialing address", address = $a
|
trace "Dialing address", address = $a
|
||||||
conn = await t.dial(a)
|
conn = await t.dial(a)
|
||||||
|
|
||||||
|
# avoid raicing with incoming connections
|
||||||
|
if peer.id in s.connections:
|
||||||
|
if not isNil(conn) and not conn.closed():
|
||||||
|
await conn.close()
|
||||||
|
conn = s.connections[peer.id]
|
||||||
|
|
||||||
# make sure to assign the peer to the connection
|
# make sure to assign the peer to the connection
|
||||||
conn.peerInfo = peer
|
if isNil(conn.peerInfo):
|
||||||
|
conn.peerInfo = peer
|
||||||
|
|
||||||
conn = await s.upgradeOutgoing(conn)
|
conn = await s.upgradeOutgoing(conn)
|
||||||
if isNil(conn):
|
if isNil(conn):
|
||||||
continue
|
continue
|
||||||
|
@ -234,7 +268,8 @@ proc dial*(s: Switch,
|
||||||
raise newException(CatchableError, "Unable to establish outgoing link")
|
raise newException(CatchableError, "Unable to establish outgoing link")
|
||||||
|
|
||||||
if proto.len > 0 and not conn.closed:
|
if proto.len > 0 and not conn.closed:
|
||||||
let stream = await s.getMuxedStream(peer)
|
result = conn
|
||||||
|
var stream = await s.getMuxedStream(peer)
|
||||||
if not isNil(stream):
|
if not isNil(stream):
|
||||||
trace "Connection is muxed, return muxed stream"
|
trace "Connection is muxed, return muxed stream"
|
||||||
result = stream
|
result = stream
|
||||||
|
|
|
@ -106,3 +106,49 @@ suite "Switch":
|
||||||
|
|
||||||
check:
|
check:
|
||||||
waitFor(testSwitch()) == true
|
waitFor(testSwitch()) == true
|
||||||
|
|
||||||
|
test "e2e use switch nested dial":
|
||||||
|
proc testSwitch(): Future[bool] {.async, gcsafe.} =
|
||||||
|
let ma1: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
|
||||||
|
let ma2: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
|
||||||
|
|
||||||
|
var peerInfo1, peerInfo2: PeerInfo
|
||||||
|
var switch1, switch2: Switch
|
||||||
|
(switch1, peerInfo1) = createSwitch(ma1)
|
||||||
|
var awaiters: seq[Future[void]]
|
||||||
|
awaiters.add(await switch1.start())
|
||||||
|
|
||||||
|
(switch2, peerInfo2) = createSwitch(ma2)
|
||||||
|
awaiters.add(await switch2.start())
|
||||||
|
|
||||||
|
var proto1 = new LPProtocol
|
||||||
|
proto1.codec = "/proto/1"
|
||||||
|
var awaiter = newFuture[void]()
|
||||||
|
proc handler1(conn: Connection, proto: string) {.async, gcsafe.} =
|
||||||
|
var nested = await switch1.dial(switch2.peerInfo, "/proto/2")
|
||||||
|
await nested.writeLp("proto 1")
|
||||||
|
check cast[string](await nested.readLp()) == "proto 2"
|
||||||
|
await nested.close()
|
||||||
|
awaiter.complete()
|
||||||
|
|
||||||
|
proto1.handler = handler1
|
||||||
|
switch1.mount(proto1)
|
||||||
|
|
||||||
|
var proto2 = new LPProtocol
|
||||||
|
proto2.codec = "/proto/2"
|
||||||
|
proc handler2(conn: Connection, proto: string) {.async, gcsafe.} =
|
||||||
|
check cast[string](await conn.readLp()) == "proto 1"
|
||||||
|
await conn.writeLp("proto 2")
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
proto2.handler = handler2
|
||||||
|
switch2.mount(proto2)
|
||||||
|
|
||||||
|
discard await switch1.dial(switch1.peerInfo, "/proto/1")
|
||||||
|
await awaiter
|
||||||
|
discard allFutures(switch1.stop(), switch2.stop())
|
||||||
|
await allFutures(awaiters)
|
||||||
|
result = true
|
||||||
|
|
||||||
|
check:
|
||||||
|
waitFor(testSwitch()) == true
|
||||||
|
|
Loading…
Reference in New Issue