allow selecting one of many protos in identify
This commit is contained in:
parent
a7e5fde6f7
commit
cc595f7947
|
@ -40,25 +40,37 @@ proc newMultistream*(): MultisteamSelect =
|
||||||
|
|
||||||
proc select*(m: MultisteamSelect,
|
proc select*(m: MultisteamSelect,
|
||||||
conn: Connection,
|
conn: Connection,
|
||||||
proto: string = ""): Future[bool] {.async.} =
|
proto: seq[string]): Future[bool] {.async.} =
|
||||||
## select a remote protocol
|
## select a remote protocol
|
||||||
## TODO: select should support a list of protos to be selected
|
|
||||||
|
|
||||||
await conn.write(m.codec) # write handshake
|
await conn.write(m.codec) # write handshake
|
||||||
if proto.len() > 0:
|
if proto.len() > 0:
|
||||||
await conn.writeLp(proto) # select proto
|
await conn.writeLp(proto[0]) # select proto
|
||||||
|
|
||||||
var ms = cast[string](await conn.readLp())
|
var ms = cast[string](await conn.readLp()) # read ms header
|
||||||
ms.removeSuffix("\n")
|
ms.removeSuffix("\n")
|
||||||
if ms != Codec:
|
if ms != Codec:
|
||||||
return false
|
return false
|
||||||
|
|
||||||
if proto.len() <= 0:
|
if proto.len() == 0: # no protocols, must be a handshake call
|
||||||
return true
|
return true
|
||||||
|
|
||||||
ms = cast[string](await conn.readLp())
|
ms = cast[string](await conn.readLp()) # read the first proto
|
||||||
ms.removeSuffix("\n")
|
ms.removeSuffix("\n")
|
||||||
result = ms == proto
|
result = ms == proto[0]
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
for p in proto[1..<proto.len()]:
|
||||||
|
await conn.writeLp(p) # select proto
|
||||||
|
ms = cast[string](await conn.readLp()) # read the first proto
|
||||||
|
ms.removeSuffix("\n")
|
||||||
|
result = ms == p
|
||||||
|
if result:
|
||||||
|
break
|
||||||
|
|
||||||
|
proc select*(m: MultisteamSelect,
|
||||||
|
conn: Connection,
|
||||||
|
proto: string = ""): Future[bool] =
|
||||||
|
result = if proto.len > 0: m.select(conn, @[proto]) else: m.select(conn, @[])
|
||||||
|
|
||||||
proc list*(m: MultisteamSelect,
|
proc list*(m: MultisteamSelect,
|
||||||
conn: Connection): Future[seq[string]] {.async.} =
|
conn: Connection): Future[seq[string]] {.async.} =
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import unittest, strutils, sequtils, sugar
|
import unittest, strutils, sequtils, sugar, strformat
|
||||||
import chronos
|
import chronos
|
||||||
import ../libp2p/connection, ../libp2p/multistream,
|
import ../libp2p/connection, ../libp2p/multistream,
|
||||||
../libp2p/stream/lpstream, ../libp2p/connection,
|
../libp2p/stream/lpstream, ../libp2p/connection,
|
||||||
|
@ -143,7 +143,7 @@ suite "Multistream select":
|
||||||
proc testSelect(): Future[bool] {.async.} =
|
proc testSelect(): Future[bool] {.async.} =
|
||||||
let ms = newMultistream()
|
let ms = newMultistream()
|
||||||
let conn = newConnection(newTestSelectStream())
|
let conn = newConnection(newTestSelectStream())
|
||||||
result = await ms.select(conn, "/test/proto/1.0.0")
|
result = await ms.select(conn, @["/test/proto/1.0.0"])
|
||||||
|
|
||||||
check:
|
check:
|
||||||
waitFor(testSelect()) == true
|
waitFor(testSelect()) == true
|
||||||
|
@ -255,7 +255,7 @@ suite "Multistream select":
|
||||||
let transport2: TcpTransport = newTransport(TcpTransport)
|
let transport2: TcpTransport = newTransport(TcpTransport)
|
||||||
let conn = await transport2.dial(ma)
|
let conn = await transport2.dial(ma)
|
||||||
|
|
||||||
let res = await msDial.select(conn, "/test/proto/1.0.0")
|
let res = await msDial.select(conn, @["/test/proto/1.0.0"])
|
||||||
check res == true
|
check res == true
|
||||||
|
|
||||||
let hello = cast[string](await conn.readLp())
|
let hello = cast[string](await conn.readLp())
|
||||||
|
@ -298,3 +298,81 @@ suite "Multistream select":
|
||||||
|
|
||||||
check:
|
check:
|
||||||
waitFor(endToEnd()) == true
|
waitFor(endToEnd()) == true
|
||||||
|
|
||||||
|
test "e2e - select one of one invalid":
|
||||||
|
proc endToEnd(): Future[bool] {.async.} =
|
||||||
|
let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53352")
|
||||||
|
|
||||||
|
let seckey = PrivateKey.random(RSA)
|
||||||
|
var peerInfo: PeerInfo
|
||||||
|
peerInfo.peerId = PeerID.init(seckey)
|
||||||
|
var protocol: LPProtocol = new LPProtocol
|
||||||
|
proc testHandler(conn: Connection,
|
||||||
|
proto: string):
|
||||||
|
Future[void] {.async, gcsafe.} =
|
||||||
|
check proto == "/test/proto/1.0.0"
|
||||||
|
await conn.writeLp("Hello!")
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
protocol.handler = testHandler
|
||||||
|
let msListen = newMultistream()
|
||||||
|
msListen.addHandler("/test/proto/1.0.0", protocol)
|
||||||
|
|
||||||
|
proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} =
|
||||||
|
await msListen.handle(conn)
|
||||||
|
|
||||||
|
let transport1: TcpTransport = newTransport(TcpTransport)
|
||||||
|
await transport1.listen(ma, connHandler)
|
||||||
|
|
||||||
|
let msDial = newMultistream()
|
||||||
|
let transport2: TcpTransport = newTransport(TcpTransport)
|
||||||
|
let conn = await transport2.dial(ma)
|
||||||
|
|
||||||
|
let res = await msDial.select(conn, @["/test/proto/1.0.0", "/test/no/proto/1.0.0"])
|
||||||
|
check res == true
|
||||||
|
|
||||||
|
let hello = cast[string](await conn.readLp())
|
||||||
|
result = hello == "Hello!"
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
check:
|
||||||
|
waitFor(endToEnd()) == true
|
||||||
|
|
||||||
|
test "e2e - select one with both valid":
|
||||||
|
proc endToEnd(): Future[bool] {.async.} =
|
||||||
|
let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53353")
|
||||||
|
|
||||||
|
let seckey = PrivateKey.random(RSA)
|
||||||
|
var peerInfo: PeerInfo
|
||||||
|
peerInfo.peerId = PeerID.init(seckey)
|
||||||
|
var protocol: LPProtocol = new LPProtocol
|
||||||
|
proc testHandler(conn: Connection,
|
||||||
|
proto: string):
|
||||||
|
Future[void] {.async, gcsafe.} =
|
||||||
|
await conn.writeLp(&"Hello from {proto}!")
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
protocol.handler = testHandler
|
||||||
|
let msListen = newMultistream()
|
||||||
|
msListen.addHandler("/test/proto1/1.0.0", protocol)
|
||||||
|
msListen.addHandler("/test/proto2/1.0.0", protocol)
|
||||||
|
|
||||||
|
proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} =
|
||||||
|
await msListen.handle(conn)
|
||||||
|
|
||||||
|
let transport1: TcpTransport = newTransport(TcpTransport)
|
||||||
|
await transport1.listen(ma, connHandler)
|
||||||
|
|
||||||
|
let msDial = newMultistream()
|
||||||
|
let transport2: TcpTransport = newTransport(TcpTransport)
|
||||||
|
let conn = await transport2.dial(ma)
|
||||||
|
|
||||||
|
let res = await msDial.select(conn, @["/test/proto2/1.0.0", "/test/proto1/1.0.0"])
|
||||||
|
check res == true
|
||||||
|
|
||||||
|
let hello = cast[string](await conn.readLp())
|
||||||
|
result = hello == "Hello from /test/proto2/1.0.0!"
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
check:
|
||||||
|
waitFor(endToEnd()) == true
|
Loading…
Reference in New Issue