more consistent dialing proto selecting logic

This commit is contained in:
Dmitriy Ryajov 2020-06-01 11:49:25 -06:00
parent 4c191866e4
commit abf659a01a
3 changed files with 22 additions and 32 deletions

View File

@ -71,12 +71,13 @@ proc select*(m: MultistreamSelect,
trace "reading first requested proto" trace "reading first requested proto"
result.removeSuffix("\n") result.removeSuffix("\n")
if result == proto[0]: if result == proto[0]:
trace "succesfully selected ", proto = proto trace "successfully selected ", proto = proto
return return
if not result.len > 0: let protos = proto[1..<proto.len()]
trace "selecting one of several protos" trace "selecting one of several protos", protos = protos
for p in proto[1..<proto.len()]: for p in protos:
trace "selecting proto", proto = p
await conn.writeLp((p & "\n")) # select proto await conn.writeLp((p & "\n")) # select proto
result = string.fromBytes(await conn.readLp(1024)) # read the first proto result = string.fromBytes(await conn.readLp(1024)) # read the first proto
result.removeSuffix("\n") result.removeSuffix("\n")
@ -157,7 +158,7 @@ proc addHandler*[T: LPProtocol](m: MultistreamSelect,
matcher: Matcher = nil) = matcher: Matcher = nil) =
## register a protocol ## register a protocol
# TODO: This is a bug in chronicles, # TODO: This is a bug in chronicles,
# it break if I uncoment this line. # it break if I uncomment this line.
# Which is almost the same as the # Which is almost the same as the
# one on the next override of addHandler # one on the next override of addHandler
# #

View File

@ -44,7 +44,7 @@ type
ms*: MultistreamSelect ms*: MultistreamSelect
identity*: Identify identity*: Identify
streamHandler*: StreamHandler streamHandler*: StreamHandler
secureManagers*: OrderedTable[string, Secure] secureManagers*: seq[Secure]
pubSub*: Option[PubSub] pubSub*: Option[PubSub]
dialedPubSubPeers: HashSet[string] dialedPubSubPeers: HashSet[string]
@ -52,17 +52,14 @@ proc newNoPubSubException(): ref CatchableError {.inline.} =
result = newException(NoPubSubException, "no pubsub provided!") result = newException(NoPubSubException, "no pubsub provided!")
proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
## secure the incoming connection if s.secureManagers.len <= 0:
let managers = toSeq(s.secureManagers.keys)
if managers.len == 0:
raise newException(CatchableError, "No secure managers registered!") raise newException(CatchableError, "No secure managers registered!")
let manager = await s.ms.select(conn, toSeq(s.secureManagers.values).mapIt(it.codec)) let manager = await s.ms.select(conn, s.secureManagers.mapIt(it.codec))
if manager.len == 0: if manager.len == 0:
raise newException(CatchableError, "Unable to negotiate a secure channel!") raise newException(CatchableError, "Unable to negotiate a secure channel!")
result = await s.secureManagers[manager].secure(conn, true) result = await s.secureManagers.filterIt(it.codec == manager)[0].secure(conn, true)
proc identify(s: Switch, conn: Connection): Future[PeerInfo] {.async, gcsafe.} = proc identify(s: Switch, conn: Connection): Future[PeerInfo] {.async, gcsafe.} =
## identify the connection ## identify the connection
@ -194,7 +191,7 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
{.async, gcsafe, closure.} = {.async, gcsafe, closure.} =
try: try:
trace "Securing connection" trace "Securing connection"
let secure = s.secureManagers[proto] let secure = s.secureManagers.filterIt(it.codec == proto)[0]
let sconn = await secure.secure(conn, false) let sconn = await secure.secure(conn, false)
if sconn.isNil: if sconn.isNil:
return return
@ -218,8 +215,8 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
try: try:
if (await ms.select(conn)): # just handshake if (await ms.select(conn)): # just handshake
# add the secure handlers # add the secure handlers
for k in s.secureManagers.keys: for k in s.secureManagers:
ms.addHandler(k, securedHandler) ms.addHandler(k.codec, securedHandler)
# handle secured connections # handle secured connections
await ms.handle(conn) await ms.handle(conn)
@ -428,7 +425,7 @@ proc newSwitch*(peerInfo: PeerInfo,
result.muxed = initTable[string, Muxer]() result.muxed = initTable[string, Muxer]()
result.identity = identity result.identity = identity
result.muxers = muxers result.muxers = muxers
result.secureManagers = initOrderedTable[string, Secure]() result.secureManagers = @secureManagers
result.dialedPubSubPeers = initHashSet[string]() result.dialedPubSubPeers = initHashSet[string]()
let s = result # can't capture result let s = result # can't capture result
@ -467,14 +464,10 @@ proc newSwitch*(peerInfo: PeerInfo,
if not(isNil(stream)): if not(isNil(stream)):
await stream.close() await stream.close()
for proto in secureManagers: if result.secureManagers.len <= 0:
trace "adding secure manager ", codec = proto.codec
result.secureManagers[proto.codec] = proto
if result.secureManagers.len == 0:
# use plain text if no secure managers are provided # use plain text if no secure managers are provided
warn "no secure managers, falling back to plain text", codec = PlainTextCodec warn "no secure managers, falling back to plain text", codec = PlainTextCodec
result.secureManagers[PlainTextCodec] = Secure(newPlainText()) result.secureManagers &= Secure(newPlainText())
if pubSub.isSome: if pubSub.isSome:
result.pubSub = pubSub result.pubSub = pubSub

View File

@ -83,7 +83,6 @@ proc testPubSubDaemonPublish(gossip: bool = false,
let smsg = cast[string](data) let smsg = cast[string](data)
check smsg == pubsubData check smsg == pubsubData
times.inc() times.inc()
echo "TIMES ", times
if times >= count and not finished: if times >= count and not finished:
finished = true finished = true
@ -108,7 +107,6 @@ proc testPubSubDaemonPublish(gossip: bool = false,
await wait(publisher(), 5.minutes) # should be plenty of time await wait(publisher(), 5.minutes) # should be plenty of time
echo "HEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE"
result = true result = true
await nativeNode.stop() await nativeNode.stop()
await allFutures(awaiters) await allFutures(awaiters)
@ -144,7 +142,6 @@ proc testPubSubNodePublish(gossip: bool = false,
let smsg = cast[string](message.data) let smsg = cast[string](message.data)
check smsg == pubsubData check smsg == pubsubData
times.inc() times.inc()
echo "TIMES ", times
if times >= count and not finished: if times >= count and not finished:
finished = true finished = true
result = true # don't cancel subscription result = true # don't cancel subscription
@ -356,7 +353,6 @@ suite "Interop":
check line == test check line == test
await conn.writeLp(cast[seq[byte]](test)) await conn.writeLp(cast[seq[byte]](test))
count.inc() count.inc()
echo "COUNT ", count
testFuture.complete(count) testFuture.complete(count)
await conn.close() await conn.close()