From 13503f3799aa7c746e17db0a4af38f8aec4115b8 Mon Sep 17 00:00:00 2001 From: lchenut Date: Wed, 18 May 2022 10:19:37 +0200 Subject: [PATCH] Circuit relay v1 (#670) Circuit relay v1 --- libp2p/builders.nim | 17 +- libp2p/dial.nim | 8 +- libp2p/dialer.nim | 3 + libp2p/multiaddress.nim | 6 +- libp2p/protocols/relay.nim | 488 +++++++++++++++++++++++++++++++++++++ libp2p/switch.nim | 9 +- tests/testinterop.nim | 157 +++++++++++- tests/testnative.nim | 3 +- tests/testrelay.nim | 350 ++++++++++++++++++++++++++ 9 files changed, 1033 insertions(+), 8 deletions(-) create mode 100644 libp2p/protocols/relay.nim create mode 100644 tests/testrelay.nim diff --git a/libp2p/builders.nim b/libp2p/builders.nim index 8d97fedc3..7e3f88d0f 100644 --- a/libp2p/builders.nim +++ b/libp2p/builders.nim @@ -14,7 +14,7 @@ import switch, peerid, peerinfo, stream/connection, multiaddress, crypto/crypto, transports/[transport, tcptransport], muxers/[muxer, mplex/mplex], - protocols/[identify, secure/secure, secure/noise], + protocols/[identify, secure/secure, secure/noise, relay], connmanager, upgrademngrs/muxedupgrade, nameresolving/nameresolver, errors @@ -48,6 +48,8 @@ type protoVersion: string agentVersion: string nameResolver: NameResolver + isCircuitRelay: bool + circuitRelayCanHop: bool proc new*(T: type[SwitchBuilder]): T = @@ -64,7 +66,8 @@ proc new*(T: type[SwitchBuilder]): T = maxOut: -1, maxConnsPerPeer: MaxConnectionsPerPeer, protoVersion: ProtoVersion, - agentVersion: AgentVersion) + agentVersion: AgentVersion, + isCircuitRelay: false) proc withPrivateKey*(b: SwitchBuilder, privateKey: PrivateKey): SwitchBuilder = b.privKey = some(privateKey) @@ -139,6 +142,11 @@ proc withNameResolver*(b: SwitchBuilder, nameResolver: NameResolver): SwitchBuil b.nameResolver = nameResolver b +proc withRelayTransport*(b: SwitchBuilder, canHop: bool): SwitchBuilder = + b.isCircuitRelay = true + b.circuitRelayCanHop = canHop + b + proc build*(b: SwitchBuilder): Switch {.raises: [Defect, LPError].} = @@ -197,6 +205,11 @@ proc build*(b: SwitchBuilder): Switch ms = ms, nameResolver = b.nameResolver) + if b.isCircuitRelay: + let relay = Relay.new(switch, b.circuitRelayCanHop) + switch.mount(relay) + switch.addTransport(RelayTransport.new(relay, muxedUpgrade)) + return switch proc newStandardSwitch*( diff --git a/libp2p/dial.nim b/libp2p/dial.nim index ea51270a0..d850a9da2 100644 --- a/libp2p/dial.nim +++ b/libp2p/dial.nim @@ -11,7 +11,8 @@ import chronos import peerid, - stream/connection + stream/connection, + transports/transport type Dial* = ref object of RootObj @@ -49,3 +50,8 @@ method dial*( ## doAssert(false, "Not implemented!") + +method addTransport*( + self: Dial, + transport: Transport) {.base.} = + doAssert(false, "Not implemented!") diff --git a/libp2p/dialer.nim b/libp2p/dialer.nim index 65cc1d628..eb89164da 100644 --- a/libp2p/dialer.nim +++ b/libp2p/dialer.nim @@ -241,6 +241,9 @@ method dial*( await cleanup() raise exc +method addTransport*(self: Dialer, t: Transport) = + self.transports &= t + proc new*( T: type Dialer, localPeerId: PeerId, diff --git a/libp2p/multiaddress.nim b/libp2p/multiaddress.nim index cf9d6598e..7d7feedb3 100644 --- a/libp2p/multiaddress.nim +++ b/libp2p/multiaddress.nim @@ -428,7 +428,9 @@ const Reliable* = mapOr(TCP, UTP, QUIC, WebSockets) - IPFS* = mapAnd(Reliable, mapEq("p2p")) + P2PPattern* = mapEq("p2p") + + IPFS* = mapAnd(Reliable, P2PPattern) HTTP* = mapOr( mapAnd(TCP, mapEq("http")), @@ -447,6 +449,8 @@ const mapAnd(HTTPS, mapEq("p2p-webrtc-direct")) ) + CircuitRelay* = mapEq("p2p-circuit") + proc initMultiAddressCodeTable(): Table[MultiCodec, MAProtocol] {.compileTime.} = for item in ProtocolsList: diff --git a/libp2p/protocols/relay.nim b/libp2p/protocols/relay.nim new file mode 100644 index 000000000..acc042b13 --- /dev/null +++ b/libp2p/protocols/relay.nim @@ -0,0 +1,488 @@ +## Nim-LibP2P +## Copyright (c) 2022 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +{.push raises: [Defect].} + +import options +import sequtils, strutils, tables +import chronos, chronicles + +import ../peerinfo, + ../switch, + ../multiaddress, + ../stream/connection, + ../protocols/protocol, + ../transports/transport, + ../utility, + ../errors + +const + RelayCodec* = "/libp2p/circuit/relay/0.1.0" + MsgSize* = 4096 + MaxCircuit* = 1024 + MaxCircuitPerPeer* = 64 + +logScope: + topics = "libp2p relay" + +type + RelayType* = enum + Hop = 1 + Stop = 2 + Status = 3 + CanHop = 4 + RelayStatus* = enum + Success = 100 + HopSrcAddrTooLong = 220 + HopDstAddrTooLong = 221 + HopSrcMultiaddrInvalid = 250 + HopDstMultiaddrInvalid = 251 + HopNoConnToDst = 260 + HopCantDialDst = 261 + HopCantOpenDstStream = 262 + HopCantSpeakRelay = 270 + HopCantRelayToSelf = 280 + StopSrcAddrTooLong = 320 + StopDstAddrTooLong = 321 + StopSrcMultiaddrInvalid = 350 + StopDstMultiaddrInvalid = 351 + StopRelayRefused = 390 + MalformedMessage = 400 + + RelayError* = object of LPError + + RelayPeer* = object + peerId*: PeerID + addrs*: seq[MultiAddress] + + AddConn* = proc(conn: Connection): Future[void] {.gcsafe, raises: [Defect].} + + RelayMessage* = object + msgType*: Option[RelayType] + srcPeer*: Option[RelayPeer] + dstPeer*: Option[RelayPeer] + status*: Option[RelayStatus] + + Relay* = ref object of LPProtocol + switch*: Switch + peerId: PeerID + dialer: Dial + canHop: bool + streamCount: int + hopCount: CountTable[PeerId] + + addConn: AddConn + + maxCircuit*: int + maxCircuitPerPeer*: int + msgSize*: int + +proc encodeMsg*(msg: RelayMessage): ProtoBuffer = + result = initProtoBuffer() + + if isSome(msg.msgType): + result.write(1, msg.msgType.get().ord.uint) + if isSome(msg.srcPeer): + var peer = initProtoBuffer() + peer.write(1, msg.srcPeer.get().peerId) + for ma in msg.srcPeer.get().addrs: + peer.write(2, ma.data.buffer) + peer.finish() + result.write(2, peer.buffer) + if isSome(msg.dstPeer): + var peer = initProtoBuffer() + peer.write(1, msg.dstPeer.get().peerId) + for ma in msg.dstPeer.get().addrs: + peer.write(2, ma.data.buffer) + peer.finish() + result.write(3, peer.buffer) + if isSome(msg.status): + result.write(4, msg.status.get().ord.uint) + + result.finish() + +proc decodeMsg*(buf: seq[byte]): Option[RelayMessage] = + var + rMsg: RelayMessage + msgTypeOrd: uint32 + src: RelayPeer + dst: RelayPeer + statusOrd: uint32 + pbSrc: ProtoBuffer + pbDst: ProtoBuffer + + let + pb = initProtoBuffer(buf) + r1 = pb.getField(1, msgTypeOrd) + r2 = pb.getField(2, pbSrc) + r3 = pb.getField(3, pbDst) + r4 = pb.getField(4, statusOrd) + + if r1.isErr() or r2.isErr() or r3.isErr() or r4.isErr(): + return none(RelayMessage) + + if r2.get() and + (pbSrc.getField(1, src.peerId).isErr() or + pbSrc.getRepeatedField(2, src.addrs).isErr()): + return none(RelayMessage) + + if r3.get() and + (pbDst.getField(1, dst.peerId).isErr() or + pbDst.getRepeatedField(2, dst.addrs).isErr()): + return none(RelayMessage) + + if r1.get(): rMsg.msgType = some(RelayType(msgTypeOrd)) + if r2.get(): rMsg.srcPeer = some(src) + if r3.get(): rMsg.dstPeer = some(dst) + if r4.get(): rMsg.status = some(RelayStatus(statusOrd)) + some(rMsg) + +proc sendStatus*(conn: Connection, code: RelayStatus) {.async, gcsafe.} = + trace "send status", status = $code & "(" & $ord(code) & ")" + let + msg = RelayMessage( + msgType: some(RelayType.Status), + status: some(code)) + pb = encodeMsg(msg) + + await conn.writeLp(pb.buffer) + +proc handleHopStream(r: Relay, conn: Connection, msg: RelayMessage) {.async, gcsafe.} = + r.streamCount.inc() + defer: + r.streamCount.dec() + + if r.streamCount > r.maxCircuit: + trace "refusing connection; too many active circuit" + await sendStatus(conn, RelayStatus.HopCantSpeakRelay) + return + + proc checkMsg(): Result[RelayMessage, RelayStatus] = + if not r.canHop: + return err(RelayStatus.HopCantSpeakRelay) + if msg.srcPeer.isNone: + return err(RelayStatus.HopSrcMultiaddrInvalid) + let src = msg.srcPeer.get() + if src.peerId != conn.peerId: + return err(RelayStatus.HopSrcMultiaddrInvalid) + if msg.dstPeer.isNone: + return err(RelayStatus.HopDstMultiaddrInvalid) + let dst = msg.dstPeer.get() + if dst.peerId == r.switch.peerInfo.peerId: + return err(RelayStatus.HopCantRelayToSelf) + if not r.switch.isConnected(dst.peerId): + trace "relay not connected to dst", dst + return err(RelayStatus.HopNoConnToDst) + ok(msg) + + let check = checkMsg() + if check.isErr: + await sendStatus(conn, check.error()) + return + let + src = msg.srcPeer.get() + dst = msg.dstPeer.get() + + # TODO: if r.acl # access control list + # and not r.acl.AllowHop(src.peerId, dst.peerId) + # sendStatus(conn, RelayStatus.HopCantSpeakRelay) + + r.hopCount.inc(src.peerId) + r.hopCount.inc(dst.peerId) + defer: + r.hopCount.inc(src.peerId, -1) + r.hopCount.inc(dst.peerId, -1) + + if r.hopCount[src.peerId] > r.maxCircuitPerPeer: + trace "refusing connection; too many connection from src", src, dst + await sendStatus(conn, RelayStatus.HopCantSpeakRelay) + return + + if r.hopCount[dst.peerId] > r.maxCircuitPerPeer: + trace "refusing connection; too many connection to dst", src, dst + await sendStatus(conn, RelayStatus.HopCantSpeakRelay) + return + + let connDst = try: + await r.switch.dial(dst.peerId, @[RelayCodec]) + except CancelledError as exc: + raise exc + except CatchableError as exc: + trace "error opening relay stream", dst, exc=exc.msg + await sendStatus(conn, RelayStatus.HopCantDialDst) + return + defer: + await connDst.close() + + let msgToSend = RelayMessage( + msgType: some(RelayType.Stop), + srcPeer: some(src), + dstPeer: some(dst), + status: none(RelayStatus)) + + let msgRcvFromDstOpt = try: + await connDst.writeLp(encodeMsg(msgToSend).buffer) + decodeMsg(await connDst.readLp(r.msgSize)) + except CancelledError as exc: + raise exc + except CatchableError as exc: + trace "error writing stop handshake or reading stop response", exc=exc.msg + await sendStatus(conn, RelayStatus.HopCantOpenDstStream) + return + + if msgRcvFromDstOpt.isNone: + trace "error reading stop response", msg = msgRcvFromDstOpt + await sendStatus(conn, RelayStatus.HopCantOpenDstStream) + return + + let msgRcvFromDst = msgRcvFromDstOpt.get() + if msgRcvFromDst.msgType.isNone or msgRcvFromDst.msgType.get() != RelayType.Status: + trace "unexcepted relay stop response", msgType = msgRcvFromDst.msgType + await sendStatus(conn, RelayStatus.HopCantOpenDstStream) + return + + if msgRcvFromDst.status.isNone or msgRcvFromDst.status.get() != RelayStatus.Success: + trace "relay stop failure", status=msgRcvFromDst.status + await sendStatus(conn, RelayStatus.HopCantOpenDstStream) + return + + await sendStatus(conn, RelayStatus.Success) + + trace "relaying connection", src, dst + + proc bridge(conn: Connection, connDst: Connection) {.async.} = + const bufferSize = 4096 + var + bufSrcToDst: array[bufferSize, byte] + bufDstToSrc: array[bufferSize, byte] + futSrc = conn.readOnce(addr bufSrcToDst[0], bufSrcToDst.high + 1) + futDst = connDst.readOnce(addr bufDstToSrc[0], bufDstToSrc.high + 1) + bytesSendFromSrcToDst = 0 + bytesSendFromDstToSrc = 0 + bufRead: int + + while not conn.closed() and not connDst.closed(): + try: + await futSrc or futDst + if futSrc.finished(): + bufRead = await futSrc + bytesSendFromSrcToDst.inc(bufRead) + await connDst.write(@bufSrcToDst[0..= 3: + result = CircuitRelay.match(sma[^2].get()) and + P2PPattern.match(sma[^1].get()) + trace "Handles return", ma, result + +proc new*(T: typedesc[RelayTransport], relay: Relay, upgrader: Upgrade): T = + result = T(relay: relay, upgrader: upgrader) + result.running = true + result.queue = newAsyncQueue[Connection](0) diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 36125587c..affc0e5ef 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -86,6 +86,11 @@ proc removePeerEventHandler*(s: Switch, kind: PeerEventKind) = s.connManager.removePeerEventHandler(handler, kind) +method addTransport*(s: Switch, + t: Transport) = + s.transports &= t + s.dialer.addTransport(t) + proc isConnected*(s: Switch, peerId: PeerId): bool = ## returns true if the peer has one or more ## associated connections (sockets) @@ -248,7 +253,7 @@ proc start*(s: Switch) {.async, gcsafe.} = it notin addrs ) - if addrs.len > 0: + if addrs.len > 0 or t.running: startFuts.add(t.start(addrs)) await allFutures(startFuts) @@ -261,7 +266,7 @@ proc start*(s: Switch) {.async, gcsafe.} = "Failed to start one transport", s.error) for t in s.transports: # for each transport - if t.addrs.len > 0: + if t.addrs.len > 0 or t.running: s.acceptFuts.add(s.accept(t)) s.peerInfo.addrs &= t.addrs diff --git a/tests/testinterop.nim b/tests/testinterop.nim index 3ab86f6cd..cecace6d0 100644 --- a/tests/testinterop.nim +++ b/tests/testinterop.nim @@ -2,7 +2,7 @@ import options, tables import chronos, chronicles, stew/byteutils import helpers import ../libp2p -import ../libp2p/[daemon/daemonapi, varint, transports/wstransport, crypto/crypto] +import ../libp2p/[daemon/daemonapi, varint, transports/wstransport, crypto/crypto, protocols/relay ] type DaemonPeerInfo = daemonapi.PeerInfo @@ -471,3 +471,158 @@ suite "Interop": asyncTest "gossipsub: node publish many": await testPubSubNodePublish(gossip = true, count = 10) + + asyncTest "NativeSrc -> NativeRelay -> DaemonDst": + proc daemonHandler(api: DaemonAPI, stream: P2PStream) {.async.} = + check "line1" == string.fromBytes(await stream.transp.readLp()) + discard await stream.transp.writeLp("line2") + check "line3" == string.fromBytes(await stream.transp.readLp()) + discard await stream.transp.writeLp("line4") + await stream.close() + let + maSrc = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + maRel = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + src = SwitchBuilder.new() + .withRng(crypto.newRng()) + .withAddresses(@[ maSrc ]) + .withTcpTransport() + .withMplex() + .withNoise() + .withRelayTransport(false) + .build() + rel = SwitchBuilder.new() + .withRng(crypto.newRng()) + .withAddresses(@[ maRel ]) + .withTcpTransport() + .withMplex() + .withNoise() + .withRelayTransport(true) + .build() + + await src.start() + await rel.start() + let daemonNode = await newDaemonApi() + let daemonPeer = await daemonNode.identity() + let maStr = $rel.peerInfo.addrs[0] & "/p2p/" & $rel.peerInfo.peerId & "/p2p-circuit/p2p/" & $daemonPeer.peer + let maddr = MultiAddress.init(maStr).tryGet() + await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + await rel.connect(daemonPeer.peer, daemonPeer.addresses) + + await daemonNode.addHandler(@[ "/testCustom" ], daemonHandler) + + let conn = await src.dial(daemonPeer.peer, @[ maddr ], @[ "/testCustom" ]) + + await conn.writeLp("line1") + check string.fromBytes(await conn.readLp(1024)) == "line2" + + await conn.writeLp("line3") + check string.fromBytes(await conn.readLp(1024)) == "line4" + + await allFutures(src.stop(), rel.stop()) + await daemonNode.close() + + asyncTest "DaemonSrc -> NativeRelay -> NativeDst": + proc customHandler(conn: Connection, proto: string) {.async.} = + check "line1" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("line2") + check "line3" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("line4") + await conn.close() + let + protos = @[ "/customProto", RelayCodec ] + var + customProto = new LPProtocol + customProto.handler = customHandler + customProto.codec = protos[0] + let + maRel = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + maDst = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + rel = SwitchBuilder.new() + .withRng(crypto.newRng()) + .withAddresses(@[ maRel ]) + .withTcpTransport() + .withMplex() + .withNoise() + .withRelayTransport(true) + .build() + dst = SwitchBuilder.new() + .withRng(crypto.newRng()) + .withAddresses(@[ maDst ]) + .withTcpTransport() + .withMplex() + .withNoise() + .withRelayTransport(false) + .build() + + dst.mount(customProto) + await rel.start() + await dst.start() + let daemonNode = await newDaemonApi() + let daemonPeer = await daemonNode.identity() + let maStr = $rel.peerInfo.addrs[0] & "/p2p/" & $rel.peerInfo.peerId & "/p2p-circuit/p2p/" & $dst.peerInfo.peerId + let maddr = MultiAddress.init(maStr).tryGet() + await daemonNode.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + await rel.connect(dst.peerInfo.peerId, dst.peerInfo.addrs) + await daemonNode.connect(dst.peerInfo.peerId, @[ maddr ]) + var stream = await daemonNode.openStream(dst.peerInfo.peerId, protos) + + discard await stream.transp.writeLp("line1") + check string.fromBytes(await stream.transp.readLp()) == "line2" + discard await stream.transp.writeLp("line3") + check string.fromBytes(await stream.transp.readLp()) == "line4" + + await allFutures(dst.stop(), rel.stop()) + await daemonNode.close() + + asyncTest "NativeSrc -> DaemonRelay -> NativeDst": + proc customHandler(conn: Connection, proto: string) {.async.} = + check "line1" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("line2") + check "line3" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("line4") + await conn.close() + let + protos = @[ "/customProto", RelayCodec ] + var + customProto = new LPProtocol + customProto.handler = customHandler + customProto.codec = protos[0] + let + maSrc = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + maDst = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + src = SwitchBuilder.new() + .withRng(crypto.newRng()) + .withAddresses(@[ maSrc ]) + .withTcpTransport() + .withMplex() + .withNoise() + .withRelayTransport(false) + .build() + dst = SwitchBuilder.new() + .withRng(crypto.newRng()) + .withAddresses(@[ maDst ]) + .withTcpTransport() + .withMplex() + .withNoise() + .withRelayTransport(false) + .build() + + dst.mount(customProto) + await src.start() + await dst.start() + let daemonNode = await newDaemonApi({RelayHop}) + let daemonPeer = await daemonNode.identity() + let maStr = $daemonPeer.addresses[0] & "/p2p/" & $daemonPeer.peer & "/p2p-circuit/p2p/" & $dst.peerInfo.peerId + let maddr = MultiAddress.init(maStr).tryGet() + await src.connect(daemonPeer.peer, daemonPeer.addresses) + await daemonNode.connect(dst.peerInfo.peerId, dst.peerInfo.addrs) + let conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0]) + + await conn.writeLp("line1") + check string.fromBytes(await conn.readLp(1024)) == "line2" + + await conn.writeLp("line3") + check string.fromBytes(await conn.readLp(1024)) == "line4" + + await allFutures(src.stop(), dst.stop()) + await daemonNode.close() diff --git a/tests/testnative.nim b/tests/testnative.nim index 1baf69abe..998f104e2 100644 --- a/tests/testnative.nim +++ b/tests/testnative.nim @@ -32,4 +32,5 @@ import testtcptransport, testpeerinfo, testpeerstore, testping, - testmplex + testmplex, + testrelay diff --git a/tests/testrelay.nim b/tests/testrelay.nim new file mode 100644 index 000000000..a9fa514e5 --- /dev/null +++ b/tests/testrelay.nim @@ -0,0 +1,350 @@ +{.used.} + +import options, bearssl, chronos +import stew/byteutils +import ../libp2p/[protocols/relay, + multiaddress, + peerinfo, + peerid, + stream/connection, + multistream, + transports/transport, + switch, + builders, + upgrademngrs/upgrade, + varint, + daemon/daemonapi] +import ./helpers + +proc new(T: typedesc[RelayTransport], relay: Relay): T = + T.new(relay = relay, upgrader = relay.switch.transports[0].upgrader) + +proc writeLp*(s: StreamTransport, msg: string | seq[byte]): Future[int] {.gcsafe.} = + ## write lenght prefixed + var buf = initVBuffer() + buf.writeSeq(msg) + buf.finish() + result = s.write(buf.buffer) + +proc readLp*(s: StreamTransport): Future[seq[byte]] {.async, gcsafe.} = + ## read length prefixed msg + var + size: uint + length: int + res: VarintResult[void] + result = newSeq[byte](10) + + for i in 0.. 0.uint: + await s.readExactly(addr result[0], int(size)) + +suite "Circuit Relay": + asyncTeardown: + await allFutures(src.stop(), dst.stop(), rel.stop()) + checkTrackers() + + var + protos {.threadvar.}: seq[string] + customProto {.threadvar.}: LPProtocol + ma {.threadvar.}: MultiAddress + src {.threadvar.}: Switch + dst {.threadvar.}: Switch + rel {.threadvar.}: Switch + relaySrc {.threadvar.}: Relay + relayDst {.threadvar.}: Relay + relayRel {.threadvar.}: Relay + conn {.threadVar.}: Connection + msg {.threadVar.}: ProtoBuffer + rcv {.threadVar.}: Option[RelayMessage] + + proc createMsg( + msgType: Option[RelayType] = RelayType.none, + status: Option[RelayStatus] = RelayStatus.none, + src: Option[RelayPeer] = RelayPeer.none, + dst: Option[RelayPeer] = RelayPeer.none): ProtoBuffer = + encodeMsg(RelayMessage(msgType: msgType, srcPeer: src, dstPeer: dst, status: status)) + + proc checkMsg(msg: Option[RelayMessage], + msgType: Option[RelayType] = none[RelayType](), + status: Option[RelayStatus] = none[RelayStatus](), + src: Option[RelayPeer] = none[RelayPeer](), + dst: Option[RelayPeer] = none[RelayPeer]()): bool = + msg.isSome and msg.get == RelayMessage(msgType: msgType, srcPeer: src, dstPeer: dst, status: status) + + proc customHandler(conn: Connection, proto: string) {.async.} = + check "line1" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("line2") + check "line3" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("line4") + await conn.close() + + asyncSetup: + # Create a custom prototype + protos = @[ "/customProto", RelayCodec ] + customProto = new LPProtocol + customProto.handler = customHandler + customProto.codec = protos[0] + ma = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + + src = newStandardSwitch() + rel = newStandardSwitch() + dst = SwitchBuilder + .new() + .withRng(newRng()) + .withAddresses(@[ ma ]) + .withTcpTransport() + .withMplex() + .withNoise() + .build() + + relaySrc = Relay.new(src, false) + relayDst = Relay.new(dst, false) + relayRel = Relay.new(rel, true) + + src.mount(relaySrc) + dst.mount(relayDst) + dst.mount(customProto) + rel.mount(relayRel) + + src.addTransport(RelayTransport.new(relaySrc)) + dst.addTransport(RelayTransport.new(relayDst)) + + await src.start() + await dst.start() + await rel.start() + + asyncTest "Handle CanHop": + msg = createMsg(some(CanHop)) + conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) + await conn.writeLp(msg.buffer) + rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) + check rcv.checkMsg(some(Status), some(RelayStatus.Success)) + + conn = await src.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec) + await conn.writeLp(msg.buffer) + rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) + check rcv.checkMsg(some(Status), some(HopCantSpeakRelay)) + + await conn.close() + + asyncTest "Malformed": + conn = await rel.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec) + msg = createMsg(some(RelayType.Status)) + await conn.writeLp(msg.buffer) + rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) + await conn.close() + check rcv.checkMsg(some(Status), some(MalformedMessage)) + + asyncTest "Handle Stop Error": + conn = await rel.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec) + msg = createMsg(some(RelayType.Stop), + none(RelayStatus), + none(RelayPeer), + some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs))) + await conn.writeLp(msg.buffer) + rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) + check rcv.checkMsg(some(Status), some(StopSrcMultiaddrInvalid)) + + conn = await rel.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec) + msg = createMsg(some(RelayType.Stop), + none(RelayStatus), + some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), + none(RelayPeer)) + await conn.writeLp(msg.buffer) + rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) + check rcv.checkMsg(some(Status), some(StopDstMultiaddrInvalid)) + + conn = await rel.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec) + msg = createMsg(some(RelayType.Stop), + none(RelayStatus), + some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)), + some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs))) + await conn.writeLp(msg.buffer) + rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) + await conn.close() + check rcv.checkMsg(some(Status), some(StopDstMultiaddrInvalid)) + + asyncTest "Handle Hop Error": + conn = await src.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec) + msg = createMsg(some(RelayType.Hop)) + await conn.writeLp(msg.buffer) + rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) + check rcv.checkMsg(some(Status), some(HopCantSpeakRelay)) + + conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) + msg = createMsg(some(RelayType.Hop), + none(RelayStatus), + none(RelayPeer), + some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs))) + await conn.writeLp(msg.buffer) + rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) + check rcv.checkMsg(some(Status), some(HopSrcMultiaddrInvalid)) + + conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) + msg = createMsg(some(RelayType.Hop), + none(RelayStatus), + some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)), + some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs))) + await conn.writeLp(msg.buffer) + rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) + check rcv.checkMsg(some(Status), some(HopSrcMultiaddrInvalid)) + + conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) + msg = createMsg(some(RelayType.Hop), + none(RelayStatus), + some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), + none(RelayPeer)) + await conn.writeLp(msg.buffer) + rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) + check rcv.checkMsg(some(Status), some(HopDstMultiaddrInvalid)) + + conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) + msg = createMsg(some(RelayType.Hop), + none(RelayStatus), + some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), + some(RelayPeer(peerId: rel.peerInfo.peerId, addrs: rel.peerInfo.addrs))) + await conn.writeLp(msg.buffer) + rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) + check rcv.checkMsg(some(Status), some(HopCantRelayToSelf)) + + conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) + msg = createMsg(some(RelayType.Hop), + none(RelayStatus), + some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), + some(RelayPeer(peerId: rel.peerInfo.peerId, addrs: rel.peerInfo.addrs))) + await conn.writeLp(msg.buffer) + rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) + check rcv.checkMsg(some(Status), some(HopCantRelayToSelf)) + + conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) + msg = createMsg(some(RelayType.Hop), + none(RelayStatus), + some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), + some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs))) + await conn.writeLp(msg.buffer) + rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) + check rcv.checkMsg(some(Status), some(HopNoConnToDst)) + + await rel.connect(dst.peerInfo.peerId, dst.peerInfo.addrs) + + relayRel.maxCircuit = 0 + conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) + await conn.writeLp(msg.buffer) + rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) + check rcv.checkMsg(some(Status), some(HopCantSpeakRelay)) + relayRel.maxCircuit = relay.MaxCircuit + await conn.close() + + relayRel.maxCircuitPerPeer = 0 + conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) + await conn.writeLp(msg.buffer) + rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) + check rcv.checkMsg(some(Status), some(HopCantSpeakRelay)) + relayRel.maxCircuitPerPeer = relay.MaxCircuitPerPeer + await conn.close() + + let dst2 = newStandardSwitch() + await dst2.start() + await rel.connect(dst2.peerInfo.peerId, dst2.peerInfo.addrs) + + conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) + msg = createMsg(some(RelayType.Hop), + none(RelayStatus), + some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), + some(RelayPeer(peerId: dst2.peerInfo.peerId, addrs: dst2.peerInfo.addrs))) + await conn.writeLp(msg.buffer) + rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) + check rcv.checkMsg(some(Status), some(HopCantDialDst)) + await allFutures(dst2.stop()) + + asyncTest "Dial Peer": + let maStr = $rel.peerInfo.addrs[0] & "/p2p/" & $rel.peerInfo.peerId & "/p2p-circuit/p2p/" & $dst.peerInfo.peerId + let maddr = MultiAddress.init(maStr).tryGet() + await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + await rel.connect(dst.peerInfo.peerId, dst.peerInfo.addrs) + conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0]) + + await conn.writeLp("line1") + check string.fromBytes(await conn.readLp(1024)) == "line2" + + await conn.writeLp("line3") + check string.fromBytes(await conn.readLp(1024)) == "line4" + + asyncTest "SwitchBuilder withRelay": + let + maSrc = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + maRel = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + maDst = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + srcWR = SwitchBuilder.new() + .withRng(newRng()) + .withAddresses(@[ maSrc ]) + .withTcpTransport() + .withMplex() + .withNoise() + .withRelayTransport(false) + .build() + relWR = SwitchBuilder.new() + .withRng(newRng()) + .withAddresses(@[ maRel ]) + .withTcpTransport() + .withMplex() + .withNoise() + .withRelayTransport(true) + .build() + dstWR = SwitchBuilder.new() + .withRng(newRng()) + .withAddresses(@[ maDst ]) + .withTcpTransport() + .withMplex() + .withNoise() + .withRelayTransport(false) + .build() + + dstWR.mount(customProto) + + await srcWR.start() + await dstWR.start() + await relWR.start() + + let maStr = $relWR.peerInfo.addrs[0] & "/p2p/" & $relWR.peerInfo.peerId & "/p2p-circuit/p2p/" & $dstWR.peerInfo.peerId + let maddr = MultiAddress.init(maStr).tryGet() + await srcWR.connect(relWR.peerInfo.peerId, relWR.peerInfo.addrs) + await relWR.connect(dstWR.peerInfo.peerId, dstWR.peerInfo.addrs) + conn = await srcWR.dial(dstWR.peerInfo.peerId, @[ maddr ], protos[0]) + + await conn.writeLp("line1") + check string.fromBytes(await conn.readLp(1024)) == "line2" + + await conn.writeLp("line3") + check string.fromBytes(await conn.readLp(1024)) == "line4" + + await allFutures(srcWR.stop(), dstWR.stop(), relWR.stop()) + + asynctest "Bad MultiAddress": + await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + await rel.connect(dst.peerInfo.peerId, dst.peerInfo.addrs) + expect(CatchableError): + let maStr = $rel.peerInfo.addrs[0] & "/p2p/" & $rel.peerInfo.peerId & "/p2p/" & $dst.peerInfo.peerId + let maddr = MultiAddress.init(maStr).tryGet() + conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0]) + + expect(CatchableError): + let maStr = $rel.peerInfo.addrs[0] & "/p2p/" & $rel.peerInfo.peerId + let maddr = MultiAddress.init(maStr).tryGet() + conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0]) + + expect(CatchableError): + let maStr = "/ip4/127.0.0.1" + let maddr = MultiAddress.init(maStr).tryGet() + conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0]) + + expect(CatchableError): + let maStr = $dst.peerInfo.peerId + let maddr = MultiAddress.init(maStr).tryGet() + conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0])