Circuit relay v1 (#670)

Circuit relay v1
This commit is contained in:
lchenut 2022-05-18 10:19:37 +02:00 committed by GitHub
parent 991549f391
commit 13503f3799
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 1033 additions and 8 deletions

View File

@ -14,7 +14,7 @@ import
switch, peerid, peerinfo, stream/connection, multiaddress,
crypto/crypto, transports/[transport, tcptransport],
muxers/[muxer, mplex/mplex],
protocols/[identify, secure/secure, secure/noise],
protocols/[identify, secure/secure, secure/noise, relay],
connmanager, upgrademngrs/muxedupgrade,
nameresolving/nameresolver,
errors
@ -48,6 +48,8 @@ type
protoVersion: string
agentVersion: string
nameResolver: NameResolver
isCircuitRelay: bool
circuitRelayCanHop: bool
proc new*(T: type[SwitchBuilder]): T =
@ -64,7 +66,8 @@ proc new*(T: type[SwitchBuilder]): T =
maxOut: -1,
maxConnsPerPeer: MaxConnectionsPerPeer,
protoVersion: ProtoVersion,
agentVersion: AgentVersion)
agentVersion: AgentVersion,
isCircuitRelay: false)
proc withPrivateKey*(b: SwitchBuilder, privateKey: PrivateKey): SwitchBuilder =
b.privKey = some(privateKey)
@ -139,6 +142,11 @@ proc withNameResolver*(b: SwitchBuilder, nameResolver: NameResolver): SwitchBuil
b.nameResolver = nameResolver
b
proc withRelayTransport*(b: SwitchBuilder, canHop: bool): SwitchBuilder =
b.isCircuitRelay = true
b.circuitRelayCanHop = canHop
b
proc build*(b: SwitchBuilder): Switch
{.raises: [Defect, LPError].} =
@ -197,6 +205,11 @@ proc build*(b: SwitchBuilder): Switch
ms = ms,
nameResolver = b.nameResolver)
if b.isCircuitRelay:
let relay = Relay.new(switch, b.circuitRelayCanHop)
switch.mount(relay)
switch.addTransport(RelayTransport.new(relay, muxedUpgrade))
return switch
proc newStandardSwitch*(

View File

@ -11,7 +11,8 @@
import chronos
import peerid,
stream/connection
stream/connection,
transports/transport
type
Dial* = ref object of RootObj
@ -49,3 +50,8 @@ method dial*(
##
doAssert(false, "Not implemented!")
method addTransport*(
self: Dial,
transport: Transport) {.base.} =
doAssert(false, "Not implemented!")

View File

@ -241,6 +241,9 @@ method dial*(
await cleanup()
raise exc
method addTransport*(self: Dialer, t: Transport) =
self.transports &= t
proc new*(
T: type Dialer,
localPeerId: PeerId,

View File

@ -428,7 +428,9 @@ const
Reliable* = mapOr(TCP, UTP, QUIC, WebSockets)
IPFS* = mapAnd(Reliable, mapEq("p2p"))
P2PPattern* = mapEq("p2p")
IPFS* = mapAnd(Reliable, P2PPattern)
HTTP* = mapOr(
mapAnd(TCP, mapEq("http")),
@ -447,6 +449,8 @@ const
mapAnd(HTTPS, mapEq("p2p-webrtc-direct"))
)
CircuitRelay* = mapEq("p2p-circuit")
proc initMultiAddressCodeTable(): Table[MultiCodec,
MAProtocol] {.compileTime.} =
for item in ProtocolsList:

488
libp2p/protocols/relay.nim Normal file
View File

@ -0,0 +1,488 @@
## Nim-LibP2P
## Copyright (c) 2022 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
{.push raises: [Defect].}
import options
import sequtils, strutils, tables
import chronos, chronicles
import ../peerinfo,
../switch,
../multiaddress,
../stream/connection,
../protocols/protocol,
../transports/transport,
../utility,
../errors
const
RelayCodec* = "/libp2p/circuit/relay/0.1.0"
MsgSize* = 4096
MaxCircuit* = 1024
MaxCircuitPerPeer* = 64
logScope:
topics = "libp2p relay"
type
RelayType* = enum
Hop = 1
Stop = 2
Status = 3
CanHop = 4
RelayStatus* = enum
Success = 100
HopSrcAddrTooLong = 220
HopDstAddrTooLong = 221
HopSrcMultiaddrInvalid = 250
HopDstMultiaddrInvalid = 251
HopNoConnToDst = 260
HopCantDialDst = 261
HopCantOpenDstStream = 262
HopCantSpeakRelay = 270
HopCantRelayToSelf = 280
StopSrcAddrTooLong = 320
StopDstAddrTooLong = 321
StopSrcMultiaddrInvalid = 350
StopDstMultiaddrInvalid = 351
StopRelayRefused = 390
MalformedMessage = 400
RelayError* = object of LPError
RelayPeer* = object
peerId*: PeerID
addrs*: seq[MultiAddress]
AddConn* = proc(conn: Connection): Future[void] {.gcsafe, raises: [Defect].}
RelayMessage* = object
msgType*: Option[RelayType]
srcPeer*: Option[RelayPeer]
dstPeer*: Option[RelayPeer]
status*: Option[RelayStatus]
Relay* = ref object of LPProtocol
switch*: Switch
peerId: PeerID
dialer: Dial
canHop: bool
streamCount: int
hopCount: CountTable[PeerId]
addConn: AddConn
maxCircuit*: int
maxCircuitPerPeer*: int
msgSize*: int
proc encodeMsg*(msg: RelayMessage): ProtoBuffer =
result = initProtoBuffer()
if isSome(msg.msgType):
result.write(1, msg.msgType.get().ord.uint)
if isSome(msg.srcPeer):
var peer = initProtoBuffer()
peer.write(1, msg.srcPeer.get().peerId)
for ma in msg.srcPeer.get().addrs:
peer.write(2, ma.data.buffer)
peer.finish()
result.write(2, peer.buffer)
if isSome(msg.dstPeer):
var peer = initProtoBuffer()
peer.write(1, msg.dstPeer.get().peerId)
for ma in msg.dstPeer.get().addrs:
peer.write(2, ma.data.buffer)
peer.finish()
result.write(3, peer.buffer)
if isSome(msg.status):
result.write(4, msg.status.get().ord.uint)
result.finish()
proc decodeMsg*(buf: seq[byte]): Option[RelayMessage] =
var
rMsg: RelayMessage
msgTypeOrd: uint32
src: RelayPeer
dst: RelayPeer
statusOrd: uint32
pbSrc: ProtoBuffer
pbDst: ProtoBuffer
let
pb = initProtoBuffer(buf)
r1 = pb.getField(1, msgTypeOrd)
r2 = pb.getField(2, pbSrc)
r3 = pb.getField(3, pbDst)
r4 = pb.getField(4, statusOrd)
if r1.isErr() or r2.isErr() or r3.isErr() or r4.isErr():
return none(RelayMessage)
if r2.get() and
(pbSrc.getField(1, src.peerId).isErr() or
pbSrc.getRepeatedField(2, src.addrs).isErr()):
return none(RelayMessage)
if r3.get() and
(pbDst.getField(1, dst.peerId).isErr() or
pbDst.getRepeatedField(2, dst.addrs).isErr()):
return none(RelayMessage)
if r1.get(): rMsg.msgType = some(RelayType(msgTypeOrd))
if r2.get(): rMsg.srcPeer = some(src)
if r3.get(): rMsg.dstPeer = some(dst)
if r4.get(): rMsg.status = some(RelayStatus(statusOrd))
some(rMsg)
proc sendStatus*(conn: Connection, code: RelayStatus) {.async, gcsafe.} =
trace "send status", status = $code & "(" & $ord(code) & ")"
let
msg = RelayMessage(
msgType: some(RelayType.Status),
status: some(code))
pb = encodeMsg(msg)
await conn.writeLp(pb.buffer)
proc handleHopStream(r: Relay, conn: Connection, msg: RelayMessage) {.async, gcsafe.} =
r.streamCount.inc()
defer:
r.streamCount.dec()
if r.streamCount > r.maxCircuit:
trace "refusing connection; too many active circuit"
await sendStatus(conn, RelayStatus.HopCantSpeakRelay)
return
proc checkMsg(): Result[RelayMessage, RelayStatus] =
if not r.canHop:
return err(RelayStatus.HopCantSpeakRelay)
if msg.srcPeer.isNone:
return err(RelayStatus.HopSrcMultiaddrInvalid)
let src = msg.srcPeer.get()
if src.peerId != conn.peerId:
return err(RelayStatus.HopSrcMultiaddrInvalid)
if msg.dstPeer.isNone:
return err(RelayStatus.HopDstMultiaddrInvalid)
let dst = msg.dstPeer.get()
if dst.peerId == r.switch.peerInfo.peerId:
return err(RelayStatus.HopCantRelayToSelf)
if not r.switch.isConnected(dst.peerId):
trace "relay not connected to dst", dst
return err(RelayStatus.HopNoConnToDst)
ok(msg)
let check = checkMsg()
if check.isErr:
await sendStatus(conn, check.error())
return
let
src = msg.srcPeer.get()
dst = msg.dstPeer.get()
# TODO: if r.acl # access control list
# and not r.acl.AllowHop(src.peerId, dst.peerId)
# sendStatus(conn, RelayStatus.HopCantSpeakRelay)
r.hopCount.inc(src.peerId)
r.hopCount.inc(dst.peerId)
defer:
r.hopCount.inc(src.peerId, -1)
r.hopCount.inc(dst.peerId, -1)
if r.hopCount[src.peerId] > r.maxCircuitPerPeer:
trace "refusing connection; too many connection from src", src, dst
await sendStatus(conn, RelayStatus.HopCantSpeakRelay)
return
if r.hopCount[dst.peerId] > r.maxCircuitPerPeer:
trace "refusing connection; too many connection to dst", src, dst
await sendStatus(conn, RelayStatus.HopCantSpeakRelay)
return
let connDst = try:
await r.switch.dial(dst.peerId, @[RelayCodec])
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error opening relay stream", dst, exc=exc.msg
await sendStatus(conn, RelayStatus.HopCantDialDst)
return
defer:
await connDst.close()
let msgToSend = RelayMessage(
msgType: some(RelayType.Stop),
srcPeer: some(src),
dstPeer: some(dst),
status: none(RelayStatus))
let msgRcvFromDstOpt = try:
await connDst.writeLp(encodeMsg(msgToSend).buffer)
decodeMsg(await connDst.readLp(r.msgSize))
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error writing stop handshake or reading stop response", exc=exc.msg
await sendStatus(conn, RelayStatus.HopCantOpenDstStream)
return
if msgRcvFromDstOpt.isNone:
trace "error reading stop response", msg = msgRcvFromDstOpt
await sendStatus(conn, RelayStatus.HopCantOpenDstStream)
return
let msgRcvFromDst = msgRcvFromDstOpt.get()
if msgRcvFromDst.msgType.isNone or msgRcvFromDst.msgType.get() != RelayType.Status:
trace "unexcepted relay stop response", msgType = msgRcvFromDst.msgType
await sendStatus(conn, RelayStatus.HopCantOpenDstStream)
return
if msgRcvFromDst.status.isNone or msgRcvFromDst.status.get() != RelayStatus.Success:
trace "relay stop failure", status=msgRcvFromDst.status
await sendStatus(conn, RelayStatus.HopCantOpenDstStream)
return
await sendStatus(conn, RelayStatus.Success)
trace "relaying connection", src, dst
proc bridge(conn: Connection, connDst: Connection) {.async.} =
const bufferSize = 4096
var
bufSrcToDst: array[bufferSize, byte]
bufDstToSrc: array[bufferSize, byte]
futSrc = conn.readOnce(addr bufSrcToDst[0], bufSrcToDst.high + 1)
futDst = connDst.readOnce(addr bufDstToSrc[0], bufDstToSrc.high + 1)
bytesSendFromSrcToDst = 0
bytesSendFromDstToSrc = 0
bufRead: int
while not conn.closed() and not connDst.closed():
try:
await futSrc or futDst
if futSrc.finished():
bufRead = await futSrc
bytesSendFromSrcToDst.inc(bufRead)
await connDst.write(@bufSrcToDst[0..<bufRead])
zeroMem(addr(bufSrcToDst), bufSrcToDst.high + 1)
futSrc = conn.readOnce(addr bufSrcToDst[0], bufSrcToDst.high + 1)
if futDst.finished():
bufRead = await futDst
bytesSendFromDstToSrc += bufRead
await conn.write(bufDstToSrc[0..<bufRead])
zeroMem(addr(bufDstToSrc), bufDstToSrc.high + 1)
futDst = connDst.readOnce(addr bufDstToSrc[0], bufDstToSrc.high + 1)
except CancelledError as exc:
raise exc
except CatchableError as exc:
if conn.closed() or conn.atEof():
trace "relay src closed connection", src
if connDst.closed() or connDst.atEof():
trace "relay dst closed connection", dst
trace "relay error", exc=exc.msg
break
trace "end relaying", bytesSendFromSrcToDst, bytesSendFromDstToSrc
await futSrc.cancelAndWait()
await futDst.cancelAndWait()
await bridge(conn, connDst)
proc handleStopStream(r: Relay, conn: Connection, msg: RelayMessage) {.async, gcsafe.} =
if msg.srcPeer.isNone:
await sendStatus(conn, RelayStatus.StopSrcMultiaddrInvalid)
return
let src = msg.srcPeer.get()
if msg.dstPeer.isNone:
await sendStatus(conn, RelayStatus.StopDstMultiaddrInvalid)
return
let dst = msg.dstPeer.get()
if dst.peerId != r.switch.peerInfo.peerId:
await sendStatus(conn, RelayStatus.StopDstMultiaddrInvalid)
return
trace "get a relay connection", src, conn
if r.addConn == nil:
await sendStatus(conn, RelayStatus.StopRelayRefused)
await conn.close()
return
await sendStatus(conn, RelayStatus.Success)
# This sound redundant but the callback could, in theory, be set to nil during
# sendStatus(Success) so it's safer to double check
if r.addConn != nil: await r.addConn(conn)
else: await conn.close()
proc handleCanHop(r: Relay, conn: Connection, msg: RelayMessage) {.async, gcsafe.} =
await sendStatus(conn,
if r.canHop:
RelayStatus.Success
else:
RelayStatus.HopCantSpeakRelay
)
proc new*(T: typedesc[Relay], switch: Switch, canHop: bool): T =
let relay = T(switch: switch, canHop: canHop)
relay.init()
relay
method init*(r: Relay) =
proc handleStream(conn: Connection, proto: string) {.async, gcsafe.} =
try:
let msgOpt = decodeMsg(await conn.readLp(r.msgSize))
if msgOpt.isNone:
await sendStatus(conn, RelayStatus.MalformedMessage)
return
else:
trace "relay handle stream", msg = msgOpt.get()
let msg = msgOpt.get()
case msg.msgType.get:
of RelayType.Hop: await r.handleHopStream(conn, msg)
of RelayType.Stop: await r.handleStopStream(conn, msg)
of RelayType.CanHop: await r.handleCanHop(conn, msg)
else:
trace "Unexpected relay handshake", msgType=msg.msgType
await sendStatus(conn, RelayStatus.MalformedMessage)
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "exception in relay handler", exc = exc.msg, conn
finally:
trace "exiting relay handler", conn
await conn.close()
r.handler = handleStream
r.codecs = @[RelayCodec]
r.maxCircuit = MaxCircuit
r.maxCircuitPerPeer = MaxCircuitPerPeer
r.msgSize = MsgSize
proc dialPeer(
r: Relay,
conn: Connection,
dstPeerId: PeerId,
dstAddrs: seq[MultiAddress]): Future[Connection] {.async.} =
var
msg = RelayMessage(
msgType: some(RelayType.Hop),
srcPeer: some(RelayPeer(peerId: r.switch.peerInfo.peerId, addrs: r.switch.peerInfo.addrs)),
dstPeer: some(RelayPeer(peerId: dstPeerId, addrs: dstAddrs)),
status: none(RelayStatus))
pb = encodeMsg(msg)
trace "Dial peer", msgSend=msg
try:
await conn.writeLp(pb.buffer)
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error writing hop request", exc=exc.msg
raise exc
let msgRcvFromRelayOpt = try:
decodeMsg(await conn.readLp(r.msgSize))
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "error reading stop response", exc=exc.msg
await sendStatus(conn, RelayStatus.HopCantOpenDstStream)
raise exc
if msgRcvFromRelayOpt.isNone:
trace "error reading stop response", msg = msgRcvFromRelayOpt
await sendStatus(conn, RelayStatus.HopCantOpenDstStream)
raise newException(RelayError, "Hop can't open destination stream")
let msgRcvFromRelay = msgRcvFromRelayOpt.get()
if msgRcvFromRelay.msgType.isNone or msgRcvFromRelay.msgType.get() != RelayType.Status:
trace "unexcepted relay stop response", msgType = msgRcvFromRelay.msgType
await sendStatus(conn, RelayStatus.HopCantOpenDstStream)
raise newException(RelayError, "Hop can't open destination stream")
if msgRcvFromRelay.status.isNone or msgRcvFromRelay.status.get() != RelayStatus.Success:
trace "relay stop failure", status=msgRcvFromRelay.status
await sendStatus(conn, RelayStatus.HopCantOpenDstStream)
raise newException(RelayError, "Hop can't open destination stream")
result = conn
#
# Relay Transport
#
type
RelayTransport* = ref object of Transport
relay*: Relay
queue: AsyncQueue[Connection]
relayRunning: bool
method start*(self: RelayTransport, ma: seq[MultiAddress]) {.async.} =
if self.relayRunning:
trace "Relay transport already running"
return
await procCall Transport(self).start(ma)
self.relayRunning = true
self.relay.addConn = proc(conn: Connection) {.async, gcsafe, raises: [Defect].} =
await self.queue.addLast(conn)
await conn.join()
trace "Starting Relay transport"
method stop*(self: RelayTransport) {.async, gcsafe.} =
self.running = false
self.relayRunning = false
self.relay.addConn = nil
while not self.queue.empty():
await self.queue.popFirstNoWait().close()
method accept*(self: RelayTransport): Future[Connection] {.async, gcsafe.} =
result = await self.queue.popFirst()
proc dial*(self: RelayTransport, ma: MultiAddress): Future[Connection] {.async, gcsafe.} =
let
sma = toSeq(ma.items())
relayAddrs = sma[0..sma.len-4].mapIt(it.tryGet()).foldl(a & b)
var
relayPeerId: PeerId
dstPeerId: PeerId
if not relayPeerId.init(($(sma[^3].get())).split('/')[2]):
raise newException(RelayError, "Relay doesn't exist")
if not dstPeerId.init(($(sma[^1].get())).split('/')[2]):
raise newException(RelayError, "Destination doesn't exist")
trace "Dial", relayPeerId, relayAddrs, dstPeerId
let conn = await self.relay.switch.dial(relayPeerId, @[ relayAddrs ], RelayCodec)
result = await self.relay.dialPeer(conn, dstPeerId, @[])
method dial*(
self: RelayTransport,
hostname: string,
address: MultiAddress): Future[Connection] {.async, gcsafe.} =
result = await self.dial(address)
method handles*(self: RelayTransport, ma: MultiAddress): bool {.gcsafe} =
if ma.protocols.isOk:
let sma = toSeq(ma.items())
if sma.len >= 3:
result = CircuitRelay.match(sma[^2].get()) and
P2PPattern.match(sma[^1].get())
trace "Handles return", ma, result
proc new*(T: typedesc[RelayTransport], relay: Relay, upgrader: Upgrade): T =
result = T(relay: relay, upgrader: upgrader)
result.running = true
result.queue = newAsyncQueue[Connection](0)

View File

@ -86,6 +86,11 @@ proc removePeerEventHandler*(s: Switch,
kind: PeerEventKind) =
s.connManager.removePeerEventHandler(handler, kind)
method addTransport*(s: Switch,
t: Transport) =
s.transports &= t
s.dialer.addTransport(t)
proc isConnected*(s: Switch, peerId: PeerId): bool =
## returns true if the peer has one or more
## associated connections (sockets)
@ -248,7 +253,7 @@ proc start*(s: Switch) {.async, gcsafe.} =
it notin addrs
)
if addrs.len > 0:
if addrs.len > 0 or t.running:
startFuts.add(t.start(addrs))
await allFutures(startFuts)
@ -261,7 +266,7 @@ proc start*(s: Switch) {.async, gcsafe.} =
"Failed to start one transport", s.error)
for t in s.transports: # for each transport
if t.addrs.len > 0:
if t.addrs.len > 0 or t.running:
s.acceptFuts.add(s.accept(t))
s.peerInfo.addrs &= t.addrs

View File

@ -2,7 +2,7 @@ import options, tables
import chronos, chronicles, stew/byteutils
import helpers
import ../libp2p
import ../libp2p/[daemon/daemonapi, varint, transports/wstransport, crypto/crypto]
import ../libp2p/[daemon/daemonapi, varint, transports/wstransport, crypto/crypto, protocols/relay ]
type
DaemonPeerInfo = daemonapi.PeerInfo
@ -471,3 +471,158 @@ suite "Interop":
asyncTest "gossipsub: node publish many":
await testPubSubNodePublish(gossip = true, count = 10)
asyncTest "NativeSrc -> NativeRelay -> DaemonDst":
proc daemonHandler(api: DaemonAPI, stream: P2PStream) {.async.} =
check "line1" == string.fromBytes(await stream.transp.readLp())
discard await stream.transp.writeLp("line2")
check "line3" == string.fromBytes(await stream.transp.readLp())
discard await stream.transp.writeLp("line4")
await stream.close()
let
maSrc = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
maRel = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
src = SwitchBuilder.new()
.withRng(crypto.newRng())
.withAddresses(@[ maSrc ])
.withTcpTransport()
.withMplex()
.withNoise()
.withRelayTransport(false)
.build()
rel = SwitchBuilder.new()
.withRng(crypto.newRng())
.withAddresses(@[ maRel ])
.withTcpTransport()
.withMplex()
.withNoise()
.withRelayTransport(true)
.build()
await src.start()
await rel.start()
let daemonNode = await newDaemonApi()
let daemonPeer = await daemonNode.identity()
let maStr = $rel.peerInfo.addrs[0] & "/p2p/" & $rel.peerInfo.peerId & "/p2p-circuit/p2p/" & $daemonPeer.peer
let maddr = MultiAddress.init(maStr).tryGet()
await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs)
await rel.connect(daemonPeer.peer, daemonPeer.addresses)
await daemonNode.addHandler(@[ "/testCustom" ], daemonHandler)
let conn = await src.dial(daemonPeer.peer, @[ maddr ], @[ "/testCustom" ])
await conn.writeLp("line1")
check string.fromBytes(await conn.readLp(1024)) == "line2"
await conn.writeLp("line3")
check string.fromBytes(await conn.readLp(1024)) == "line4"
await allFutures(src.stop(), rel.stop())
await daemonNode.close()
asyncTest "DaemonSrc -> NativeRelay -> NativeDst":
proc customHandler(conn: Connection, proto: string) {.async.} =
check "line1" == string.fromBytes(await conn.readLp(1024))
await conn.writeLp("line2")
check "line3" == string.fromBytes(await conn.readLp(1024))
await conn.writeLp("line4")
await conn.close()
let
protos = @[ "/customProto", RelayCodec ]
var
customProto = new LPProtocol
customProto.handler = customHandler
customProto.codec = protos[0]
let
maRel = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
maDst = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
rel = SwitchBuilder.new()
.withRng(crypto.newRng())
.withAddresses(@[ maRel ])
.withTcpTransport()
.withMplex()
.withNoise()
.withRelayTransport(true)
.build()
dst = SwitchBuilder.new()
.withRng(crypto.newRng())
.withAddresses(@[ maDst ])
.withTcpTransport()
.withMplex()
.withNoise()
.withRelayTransport(false)
.build()
dst.mount(customProto)
await rel.start()
await dst.start()
let daemonNode = await newDaemonApi()
let daemonPeer = await daemonNode.identity()
let maStr = $rel.peerInfo.addrs[0] & "/p2p/" & $rel.peerInfo.peerId & "/p2p-circuit/p2p/" & $dst.peerInfo.peerId
let maddr = MultiAddress.init(maStr).tryGet()
await daemonNode.connect(rel.peerInfo.peerId, rel.peerInfo.addrs)
await rel.connect(dst.peerInfo.peerId, dst.peerInfo.addrs)
await daemonNode.connect(dst.peerInfo.peerId, @[ maddr ])
var stream = await daemonNode.openStream(dst.peerInfo.peerId, protos)
discard await stream.transp.writeLp("line1")
check string.fromBytes(await stream.transp.readLp()) == "line2"
discard await stream.transp.writeLp("line3")
check string.fromBytes(await stream.transp.readLp()) == "line4"
await allFutures(dst.stop(), rel.stop())
await daemonNode.close()
asyncTest "NativeSrc -> DaemonRelay -> NativeDst":
proc customHandler(conn: Connection, proto: string) {.async.} =
check "line1" == string.fromBytes(await conn.readLp(1024))
await conn.writeLp("line2")
check "line3" == string.fromBytes(await conn.readLp(1024))
await conn.writeLp("line4")
await conn.close()
let
protos = @[ "/customProto", RelayCodec ]
var
customProto = new LPProtocol
customProto.handler = customHandler
customProto.codec = protos[0]
let
maSrc = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
maDst = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
src = SwitchBuilder.new()
.withRng(crypto.newRng())
.withAddresses(@[ maSrc ])
.withTcpTransport()
.withMplex()
.withNoise()
.withRelayTransport(false)
.build()
dst = SwitchBuilder.new()
.withRng(crypto.newRng())
.withAddresses(@[ maDst ])
.withTcpTransport()
.withMplex()
.withNoise()
.withRelayTransport(false)
.build()
dst.mount(customProto)
await src.start()
await dst.start()
let daemonNode = await newDaemonApi({RelayHop})
let daemonPeer = await daemonNode.identity()
let maStr = $daemonPeer.addresses[0] & "/p2p/" & $daemonPeer.peer & "/p2p-circuit/p2p/" & $dst.peerInfo.peerId
let maddr = MultiAddress.init(maStr).tryGet()
await src.connect(daemonPeer.peer, daemonPeer.addresses)
await daemonNode.connect(dst.peerInfo.peerId, dst.peerInfo.addrs)
let conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0])
await conn.writeLp("line1")
check string.fromBytes(await conn.readLp(1024)) == "line2"
await conn.writeLp("line3")
check string.fromBytes(await conn.readLp(1024)) == "line4"
await allFutures(src.stop(), dst.stop())
await daemonNode.close()

View File

@ -32,4 +32,5 @@ import testtcptransport,
testpeerinfo,
testpeerstore,
testping,
testmplex
testmplex,
testrelay

350
tests/testrelay.nim Normal file
View File

@ -0,0 +1,350 @@
{.used.}
import options, bearssl, chronos
import stew/byteutils
import ../libp2p/[protocols/relay,
multiaddress,
peerinfo,
peerid,
stream/connection,
multistream,
transports/transport,
switch,
builders,
upgrademngrs/upgrade,
varint,
daemon/daemonapi]
import ./helpers
proc new(T: typedesc[RelayTransport], relay: Relay): T =
T.new(relay = relay, upgrader = relay.switch.transports[0].upgrader)
proc writeLp*(s: StreamTransport, msg: string | seq[byte]): Future[int] {.gcsafe.} =
## write lenght prefixed
var buf = initVBuffer()
buf.writeSeq(msg)
buf.finish()
result = s.write(buf.buffer)
proc readLp*(s: StreamTransport): Future[seq[byte]] {.async, gcsafe.} =
## read length prefixed msg
var
size: uint
length: int
res: VarintResult[void]
result = newSeq[byte](10)
for i in 0..<len(result):
await s.readExactly(addr result[i], 1)
res = LP.getUVarint(result.toOpenArray(0, i), length, size)
if res.isOk():
break
res.expect("Valid varint")
result.setLen(size)
if size > 0.uint:
await s.readExactly(addr result[0], int(size))
suite "Circuit Relay":
asyncTeardown:
await allFutures(src.stop(), dst.stop(), rel.stop())
checkTrackers()
var
protos {.threadvar.}: seq[string]
customProto {.threadvar.}: LPProtocol
ma {.threadvar.}: MultiAddress
src {.threadvar.}: Switch
dst {.threadvar.}: Switch
rel {.threadvar.}: Switch
relaySrc {.threadvar.}: Relay
relayDst {.threadvar.}: Relay
relayRel {.threadvar.}: Relay
conn {.threadVar.}: Connection
msg {.threadVar.}: ProtoBuffer
rcv {.threadVar.}: Option[RelayMessage]
proc createMsg(
msgType: Option[RelayType] = RelayType.none,
status: Option[RelayStatus] = RelayStatus.none,
src: Option[RelayPeer] = RelayPeer.none,
dst: Option[RelayPeer] = RelayPeer.none): ProtoBuffer =
encodeMsg(RelayMessage(msgType: msgType, srcPeer: src, dstPeer: dst, status: status))
proc checkMsg(msg: Option[RelayMessage],
msgType: Option[RelayType] = none[RelayType](),
status: Option[RelayStatus] = none[RelayStatus](),
src: Option[RelayPeer] = none[RelayPeer](),
dst: Option[RelayPeer] = none[RelayPeer]()): bool =
msg.isSome and msg.get == RelayMessage(msgType: msgType, srcPeer: src, dstPeer: dst, status: status)
proc customHandler(conn: Connection, proto: string) {.async.} =
check "line1" == string.fromBytes(await conn.readLp(1024))
await conn.writeLp("line2")
check "line3" == string.fromBytes(await conn.readLp(1024))
await conn.writeLp("line4")
await conn.close()
asyncSetup:
# Create a custom prototype
protos = @[ "/customProto", RelayCodec ]
customProto = new LPProtocol
customProto.handler = customHandler
customProto.codec = protos[0]
ma = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
src = newStandardSwitch()
rel = newStandardSwitch()
dst = SwitchBuilder
.new()
.withRng(newRng())
.withAddresses(@[ ma ])
.withTcpTransport()
.withMplex()
.withNoise()
.build()
relaySrc = Relay.new(src, false)
relayDst = Relay.new(dst, false)
relayRel = Relay.new(rel, true)
src.mount(relaySrc)
dst.mount(relayDst)
dst.mount(customProto)
rel.mount(relayRel)
src.addTransport(RelayTransport.new(relaySrc))
dst.addTransport(RelayTransport.new(relayDst))
await src.start()
await dst.start()
await rel.start()
asyncTest "Handle CanHop":
msg = createMsg(some(CanHop))
conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
await conn.writeLp(msg.buffer)
rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
check rcv.checkMsg(some(Status), some(RelayStatus.Success))
conn = await src.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec)
await conn.writeLp(msg.buffer)
rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
check rcv.checkMsg(some(Status), some(HopCantSpeakRelay))
await conn.close()
asyncTest "Malformed":
conn = await rel.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec)
msg = createMsg(some(RelayType.Status))
await conn.writeLp(msg.buffer)
rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
await conn.close()
check rcv.checkMsg(some(Status), some(MalformedMessage))
asyncTest "Handle Stop Error":
conn = await rel.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec)
msg = createMsg(some(RelayType.Stop),
none(RelayStatus),
none(RelayPeer),
some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)))
await conn.writeLp(msg.buffer)
rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
check rcv.checkMsg(some(Status), some(StopSrcMultiaddrInvalid))
conn = await rel.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec)
msg = createMsg(some(RelayType.Stop),
none(RelayStatus),
some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
none(RelayPeer))
await conn.writeLp(msg.buffer)
rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
check rcv.checkMsg(some(Status), some(StopDstMultiaddrInvalid))
conn = await rel.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec)
msg = createMsg(some(RelayType.Stop),
none(RelayStatus),
some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)),
some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)))
await conn.writeLp(msg.buffer)
rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
await conn.close()
check rcv.checkMsg(some(Status), some(StopDstMultiaddrInvalid))
asyncTest "Handle Hop Error":
conn = await src.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec)
msg = createMsg(some(RelayType.Hop))
await conn.writeLp(msg.buffer)
rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
check rcv.checkMsg(some(Status), some(HopCantSpeakRelay))
conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
msg = createMsg(some(RelayType.Hop),
none(RelayStatus),
none(RelayPeer),
some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)))
await conn.writeLp(msg.buffer)
rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
check rcv.checkMsg(some(Status), some(HopSrcMultiaddrInvalid))
conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
msg = createMsg(some(RelayType.Hop),
none(RelayStatus),
some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)),
some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)))
await conn.writeLp(msg.buffer)
rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
check rcv.checkMsg(some(Status), some(HopSrcMultiaddrInvalid))
conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
msg = createMsg(some(RelayType.Hop),
none(RelayStatus),
some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
none(RelayPeer))
await conn.writeLp(msg.buffer)
rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
check rcv.checkMsg(some(Status), some(HopDstMultiaddrInvalid))
conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
msg = createMsg(some(RelayType.Hop),
none(RelayStatus),
some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
some(RelayPeer(peerId: rel.peerInfo.peerId, addrs: rel.peerInfo.addrs)))
await conn.writeLp(msg.buffer)
rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
check rcv.checkMsg(some(Status), some(HopCantRelayToSelf))
conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
msg = createMsg(some(RelayType.Hop),
none(RelayStatus),
some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
some(RelayPeer(peerId: rel.peerInfo.peerId, addrs: rel.peerInfo.addrs)))
await conn.writeLp(msg.buffer)
rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
check rcv.checkMsg(some(Status), some(HopCantRelayToSelf))
conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
msg = createMsg(some(RelayType.Hop),
none(RelayStatus),
some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)))
await conn.writeLp(msg.buffer)
rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
check rcv.checkMsg(some(Status), some(HopNoConnToDst))
await rel.connect(dst.peerInfo.peerId, dst.peerInfo.addrs)
relayRel.maxCircuit = 0
conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
await conn.writeLp(msg.buffer)
rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
check rcv.checkMsg(some(Status), some(HopCantSpeakRelay))
relayRel.maxCircuit = relay.MaxCircuit
await conn.close()
relayRel.maxCircuitPerPeer = 0
conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
await conn.writeLp(msg.buffer)
rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
check rcv.checkMsg(some(Status), some(HopCantSpeakRelay))
relayRel.maxCircuitPerPeer = relay.MaxCircuitPerPeer
await conn.close()
let dst2 = newStandardSwitch()
await dst2.start()
await rel.connect(dst2.peerInfo.peerId, dst2.peerInfo.addrs)
conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec)
msg = createMsg(some(RelayType.Hop),
none(RelayStatus),
some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)),
some(RelayPeer(peerId: dst2.peerInfo.peerId, addrs: dst2.peerInfo.addrs)))
await conn.writeLp(msg.buffer)
rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize))
check rcv.checkMsg(some(Status), some(HopCantDialDst))
await allFutures(dst2.stop())
asyncTest "Dial Peer":
let maStr = $rel.peerInfo.addrs[0] & "/p2p/" & $rel.peerInfo.peerId & "/p2p-circuit/p2p/" & $dst.peerInfo.peerId
let maddr = MultiAddress.init(maStr).tryGet()
await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs)
await rel.connect(dst.peerInfo.peerId, dst.peerInfo.addrs)
conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0])
await conn.writeLp("line1")
check string.fromBytes(await conn.readLp(1024)) == "line2"
await conn.writeLp("line3")
check string.fromBytes(await conn.readLp(1024)) == "line4"
asyncTest "SwitchBuilder withRelay":
let
maSrc = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
maRel = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
maDst = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
srcWR = SwitchBuilder.new()
.withRng(newRng())
.withAddresses(@[ maSrc ])
.withTcpTransport()
.withMplex()
.withNoise()
.withRelayTransport(false)
.build()
relWR = SwitchBuilder.new()
.withRng(newRng())
.withAddresses(@[ maRel ])
.withTcpTransport()
.withMplex()
.withNoise()
.withRelayTransport(true)
.build()
dstWR = SwitchBuilder.new()
.withRng(newRng())
.withAddresses(@[ maDst ])
.withTcpTransport()
.withMplex()
.withNoise()
.withRelayTransport(false)
.build()
dstWR.mount(customProto)
await srcWR.start()
await dstWR.start()
await relWR.start()
let maStr = $relWR.peerInfo.addrs[0] & "/p2p/" & $relWR.peerInfo.peerId & "/p2p-circuit/p2p/" & $dstWR.peerInfo.peerId
let maddr = MultiAddress.init(maStr).tryGet()
await srcWR.connect(relWR.peerInfo.peerId, relWR.peerInfo.addrs)
await relWR.connect(dstWR.peerInfo.peerId, dstWR.peerInfo.addrs)
conn = await srcWR.dial(dstWR.peerInfo.peerId, @[ maddr ], protos[0])
await conn.writeLp("line1")
check string.fromBytes(await conn.readLp(1024)) == "line2"
await conn.writeLp("line3")
check string.fromBytes(await conn.readLp(1024)) == "line4"
await allFutures(srcWR.stop(), dstWR.stop(), relWR.stop())
asynctest "Bad MultiAddress":
await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs)
await rel.connect(dst.peerInfo.peerId, dst.peerInfo.addrs)
expect(CatchableError):
let maStr = $rel.peerInfo.addrs[0] & "/p2p/" & $rel.peerInfo.peerId & "/p2p/" & $dst.peerInfo.peerId
let maddr = MultiAddress.init(maStr).tryGet()
conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0])
expect(CatchableError):
let maStr = $rel.peerInfo.addrs[0] & "/p2p/" & $rel.peerInfo.peerId
let maddr = MultiAddress.init(maStr).tryGet()
conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0])
expect(CatchableError):
let maStr = "/ip4/127.0.0.1"
let maddr = MultiAddress.init(maStr).tryGet()
conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0])
expect(CatchableError):
let maStr = $dst.peerInfo.peerId
let maddr = MultiAddress.init(maStr).tryGet()
conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0])