351 lines
13 KiB
Nim
351 lines
13 KiB
Nim
|
{.used.}
|
||
|
|
||
|
import options, bearssl, chronos
|
||
|
import stew/byteutils
|
||
|
import ../libp2p/[protocols/relay,
|
||
|
multiaddress,
|
||
|
peerinfo,
|
||
|
peerid,
|
||
|
stream/connection,
|
||
|
multistream,
|
||
|
transports/transport,
|
||
|
switch,
|
||
|
builders,
|
||
|
upgrademngrs/upgrade,
|
||
|
varint,
|
||
|
daemon/daemonapi]
|
||
|
import ./helpers
|
||
|
|
||
|
proc new(T: typedesc[RelayTransport], relay: Relay): T =
  ## Test-only convenience constructor for a `RelayTransport`.
  ##
  ## Forwards to the `new` overload that takes an explicit `upgrader`,
  ## reusing the upgrader of the first transport registered on the relay's
  ## switch. Indexing `transports[0]` requires the switch to already have at
  ## least one transport.
  T.new(relay = relay, upgrader = relay.switch.transports[0].upgrader)
|
||
|
|
||
|
proc writeLp*(s: StreamTransport, msg: string | seq[byte]): Future[int] {.gcsafe.} =
  ## Write `msg` to `s` as a length-prefixed frame.
  ##
  ## The frame is assembled in a `VBuffer` with `writeSeq` and then handed to
  ## `s.write`; the future returned here is the one produced by that write.
  var buf = initVBuffer()
  buf.writeSeq(msg)
  buf.finish()
  result = s.write(buf.buffer)
|
||
|
|
||
|
proc readLp*(s: StreamTransport): Future[seq[byte]] {.async, gcsafe.} =
  ## Read one length-prefixed message from `s` and return its payload.
  ##
  ## The prefix is read one byte at a time (at most 10 bytes) until
  ## `LP.getUVarint` accepts the bytes read so far. If no valid varint is
  ## found, `expect` fails with "Valid varint". An empty payload yields an
  ## empty seq.
  var
    size: uint
    length: int
    res: VarintResult[void]
  # Scratch space for the prefix; resized to the payload length below.
  result = newSeq[byte](10)

  for i in 0..<len(result):
    await s.readExactly(addr result[i], 1)
    # Re-parse everything read so far; stop at the first complete varint.
    res = LP.getUVarint(result.toOpenArray(0, i), length, size)
    if res.isOk():
      break
  res.expect("Valid varint")

  # Convert once so `setLen` and `readExactly` agree on the payload length.
  let payloadLen = int(size)
  result.setLen(payloadLen)
  if payloadLen > 0:
    # The prefix bytes are no longer needed; the payload overwrites them.
    await s.readExactly(addr result[0], payloadLen)
|
||
|
|
||
|
suite "Circuit Relay":
  # Exercises the relay protocol between three switches created in
  # `asyncSetup`: `src`, `rel` and `dst`. Each switch mounts its own `Relay`;
  # only `rel`'s is created with `true` as second argument (presumably the
  # "can hop" flag -- confirm against `Relay.new`).
  #
  # Every stream obtained through `dial` is closed before `conn` is reused.

  asyncTeardown:
    await allFutures(src.stop(), dst.stop(), rel.stop())
    checkTrackers()

  var
    protos {.threadvar.}: seq[string]
    customProto {.threadvar.}: LPProtocol
    ma {.threadvar.}: MultiAddress
    src {.threadvar.}: Switch
    dst {.threadvar.}: Switch
    rel {.threadvar.}: Switch
    relaySrc {.threadvar.}: Relay
    relayDst {.threadvar.}: Relay
    relayRel {.threadvar.}: Relay
    conn {.threadvar.}: Connection
    msg {.threadvar.}: ProtoBuffer
    rcv {.threadvar.}: Option[RelayMessage]

  proc createMsg(
      msgType: Option[RelayType] = RelayType.none,
      status: Option[RelayStatus] = RelayStatus.none,
      src: Option[RelayPeer] = RelayPeer.none,
      dst: Option[RelayPeer] = RelayPeer.none): ProtoBuffer =
    ## Encode a `RelayMessage` built from the given (all optional) fields.
    encodeMsg(RelayMessage(msgType: msgType, srcPeer: src, dstPeer: dst, status: status))

  proc checkMsg(msg: Option[RelayMessage],
                msgType: Option[RelayType] = none[RelayType](),
                status: Option[RelayStatus] = none[RelayStatus](),
                src: Option[RelayPeer] = none[RelayPeer](),
                dst: Option[RelayPeer] = none[RelayPeer]()): bool =
    ## True when `msg` is present and equals a `RelayMessage` made of exactly
    ## the given fields.
    msg.isSome and msg.get == RelayMessage(msgType: msgType, srcPeer: src, dstPeer: dst, status: status)

  proc customHandler(conn: Connection, proto: string) {.async.} =
    ## Handler for the custom protocol: expects "line1", answers "line2",
    ## expects "line3", answers "line4", then closes the connection.
    check "line1" == string.fromBytes(await conn.readLp(1024))
    await conn.writeLp("line2")
    check "line3" == string.fromBytes(await conn.readLp(1024))
    await conn.writeLp("line4")
    await conn.close()

  asyncSetup:
    # Create a custom prototype
    protos = @[ "/customProto", RelayCodec ]
    customProto = new LPProtocol
    customProto.handler = customHandler
    customProto.codec = protos[0]
    ma = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()

    src = newStandardSwitch()
    rel = newStandardSwitch()
    dst = SwitchBuilder
      .new()
      .withRng(newRng())
      .withAddresses(@[ ma ])
      .withTcpTransport()
      .withMplex()
      .withNoise()
      .build()

    relaySrc = Relay.new(src, false)
    relayDst = Relay.new(dst, false)
    relayRel = Relay.new(rel, true)

    src.mount(relaySrc)
    dst.mount(relayDst)
    dst.mount(customProto)
    rel.mount(relayRel)

    src.addTransport(RelayTransport.new(relaySrc))
    dst.addTransport(RelayTransport.new(relayDst))

    await src.start()
    await dst.start()
    await rel.start()

  asyncTest "Handle CanHop":
    msg = createMsg(some(CanHop))

    # `rel` is expected to answer the CanHop query with Success.
    conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
    await conn.writeLp(msg.buffer)
    rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
    check rcv.checkMsg(some(Status), some(RelayStatus.Success))
    await conn.close()

    # `dst` is expected to refuse the same query.
    conn = await src.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec)
    await conn.writeLp(msg.buffer)
    rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
    check rcv.checkMsg(some(Status), some(HopCantSpeakRelay))

    await conn.close()

  asyncTest "Malformed":
    # A bare Status message sent as a request must be rejected.
    conn = await rel.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec)
    msg = createMsg(some(RelayType.Status))
    await conn.writeLp(msg.buffer)
    rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
    await conn.close()
    check rcv.checkMsg(some(Status), some(MalformedMessage))

  asyncTest "Handle Stop Error":
    # Stop without a source peer.
    conn = await rel.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec)
    msg = createMsg(some(RelayType.Stop),
      none(RelayStatus),
      none(RelayPeer),
      some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)))
    await conn.writeLp(msg.buffer)
    rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
    check rcv.checkMsg(some(Status), some(StopSrcMultiaddrInvalid))
    await conn.close()

    # Stop without a destination peer.
    conn = await rel.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec)
    msg = createMsg(some(RelayType.Stop),
      none(RelayStatus),
      some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
      none(RelayPeer))
    await conn.writeLp(msg.buffer)
    rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
    check rcv.checkMsg(some(Status), some(StopDstMultiaddrInvalid))
    await conn.close()

    # Stop whose destination peer is `src` although it is sent to `dst`.
    conn = await rel.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec)
    msg = createMsg(some(RelayType.Stop),
      none(RelayStatus),
      some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)),
      some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)))
    await conn.writeLp(msg.buffer)
    rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
    await conn.close()
    check rcv.checkMsg(some(Status), some(StopDstMultiaddrInvalid))

  asyncTest "Handle Hop Error":
    # Hop sent to `dst`, whose relay was created with `false`.
    conn = await src.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec)
    msg = createMsg(some(RelayType.Hop))
    await conn.writeLp(msg.buffer)
    rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
    check rcv.checkMsg(some(Status), some(HopCantSpeakRelay))
    await conn.close()

    # Hop without a source peer.
    conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
    msg = createMsg(some(RelayType.Hop),
      none(RelayStatus),
      none(RelayPeer),
      some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)))
    await conn.writeLp(msg.buffer)
    rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
    check rcv.checkMsg(some(Status), some(HopSrcMultiaddrInvalid))
    await conn.close()

    # Hop whose source peer is `dst` although `src` is the dialer.
    conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
    msg = createMsg(some(RelayType.Hop),
      none(RelayStatus),
      some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)),
      some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)))
    await conn.writeLp(msg.buffer)
    rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
    check rcv.checkMsg(some(Status), some(HopSrcMultiaddrInvalid))
    await conn.close()

    # Hop without a destination peer.
    conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
    msg = createMsg(some(RelayType.Hop),
      none(RelayStatus),
      some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
      none(RelayPeer))
    await conn.writeLp(msg.buffer)
    rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
    check rcv.checkMsg(some(Status), some(HopDstMultiaddrInvalid))
    await conn.close()

    # Hop whose destination is the relay itself.
    conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
    msg = createMsg(some(RelayType.Hop),
      none(RelayStatus),
      some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
      some(RelayPeer(peerId: rel.peerInfo.peerId, addrs: rel.peerInfo.addrs)))
    await conn.writeLp(msg.buffer)
    rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
    check rcv.checkMsg(some(Status), some(HopCantRelayToSelf))
    await conn.close()

    # Well-formed Hop towards `dst`, but `rel` has not connected to `dst` yet.
    conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
    msg = createMsg(some(RelayType.Hop),
      none(RelayStatus),
      some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
      some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)))
    await conn.writeLp(msg.buffer)
    rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
    check rcv.checkMsg(some(Status), some(HopNoConnToDst))
    await conn.close()

    await rel.connect(dst.peerInfo.peerId, dst.peerInfo.addrs)

    # Same Hop message, with the relay's circuit limit set to zero.
    relayRel.maxCircuit = 0
    conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
    await conn.writeLp(msg.buffer)
    rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
    check rcv.checkMsg(some(Status), some(HopCantSpeakRelay))
    relayRel.maxCircuit = relay.MaxCircuit
    await conn.close()

    # Same Hop message, with the per-peer circuit limit set to zero.
    relayRel.maxCircuitPerPeer = 0
    conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
    await conn.writeLp(msg.buffer)
    rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
    check rcv.checkMsg(some(Status), some(HopCantSpeakRelay))
    relayRel.maxCircuitPerPeer = relay.MaxCircuitPerPeer
    await conn.close()

    # Hop towards `dst2`, a standard switch on which this test mounts no Relay.
    let dst2 = newStandardSwitch()
    await dst2.start()
    await rel.connect(dst2.peerInfo.peerId, dst2.peerInfo.addrs)

    conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
    msg = createMsg(some(RelayType.Hop),
      none(RelayStatus),
      some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
      some(RelayPeer(peerId: dst2.peerInfo.peerId, addrs: dst2.peerInfo.addrs)))
    await conn.writeLp(msg.buffer)
    rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
    check rcv.checkMsg(some(Status), some(HopCantDialDst))
    await conn.close()
    await allFutures(dst2.stop())

  asyncTest "Dial Peer":
    # Dial `dst` through `rel` using a /p2p-circuit address and run the
    # custom protocol exchange over the relayed stream.
    let maStr = $rel.peerInfo.addrs[0] & "/p2p/" & $rel.peerInfo.peerId & "/p2p-circuit/p2p/" & $dst.peerInfo.peerId
    let maddr = MultiAddress.init(maStr).tryGet()
    await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs)
    await rel.connect(dst.peerInfo.peerId, dst.peerInfo.addrs)
    conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0])

    await conn.writeLp("line1")
    check string.fromBytes(await conn.readLp(1024)) == "line2"

    await conn.writeLp("line3")
    check string.fromBytes(await conn.readLp(1024)) == "line4"

    await conn.close()

  asyncTest "SwitchBuilder withRelay":
    # Same exchange as "Dial Peer", but with switches whose relay support is
    # configured through `SwitchBuilder.withRelayTransport`.
    let
      maSrc = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
      maRel = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
      maDst = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
      srcWR = SwitchBuilder.new()
        .withRng(newRng())
        .withAddresses(@[ maSrc ])
        .withTcpTransport()
        .withMplex()
        .withNoise()
        .withRelayTransport(false)
        .build()
      relWR = SwitchBuilder.new()
        .withRng(newRng())
        .withAddresses(@[ maRel ])
        .withTcpTransport()
        .withMplex()
        .withNoise()
        .withRelayTransport(true)
        .build()
      dstWR = SwitchBuilder.new()
        .withRng(newRng())
        .withAddresses(@[ maDst ])
        .withTcpTransport()
        .withMplex()
        .withNoise()
        .withRelayTransport(false)
        .build()

    dstWR.mount(customProto)

    await srcWR.start()
    await dstWR.start()
    await relWR.start()

    let maStr = $relWR.peerInfo.addrs[0] & "/p2p/" & $relWR.peerInfo.peerId & "/p2p-circuit/p2p/" & $dstWR.peerInfo.peerId
    let maddr = MultiAddress.init(maStr).tryGet()
    await srcWR.connect(relWR.peerInfo.peerId, relWR.peerInfo.addrs)
    await relWR.connect(dstWR.peerInfo.peerId, dstWR.peerInfo.addrs)
    conn = await srcWR.dial(dstWR.peerInfo.peerId, @[ maddr ], protos[0])

    await conn.writeLp("line1")
    check string.fromBytes(await conn.readLp(1024)) == "line2"

    await conn.writeLp("line3")
    check string.fromBytes(await conn.readLp(1024)) == "line4"

    await conn.close()
    await allFutures(srcWR.stop(), dstWR.stop(), relWR.stop())

  asyncTest "Bad MultiAddress":
    # Each block must raise a CatchableError, either while building the
    # address or while dialing with it.
    await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs)
    await rel.connect(dst.peerInfo.peerId, dst.peerInfo.addrs)

    # Relay and destination joined by "/p2p/" instead of "/p2p-circuit/p2p/".
    expect(CatchableError):
      let maStr = $rel.peerInfo.addrs[0] & "/p2p/" & $rel.peerInfo.peerId & "/p2p/" & $dst.peerInfo.peerId
      let maddr = MultiAddress.init(maStr).tryGet()
      conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0])

    # Relay address only, no circuit part.
    expect(CatchableError):
      let maStr = $rel.peerInfo.addrs[0] & "/p2p/" & $rel.peerInfo.peerId
      let maddr = MultiAddress.init(maStr).tryGet()
      conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0])

    # Bare IP address.
    expect(CatchableError):
      let maStr = "/ip4/127.0.0.1"
      let maddr = MultiAddress.init(maStr).tryGet()
      conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0])

    # Bare peer id string.
    expect(CatchableError):
      let maStr = $dst.peerInfo.peerId
      let maddr = MultiAddress.init(maStr).tryGet()
      conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0])
|