diff --git a/examples/circuitrelay.nim b/examples/circuitrelay.nim new file mode 100644 index 000000000..290c47399 --- /dev/null +++ b/examples/circuitrelay.nim @@ -0,0 +1,76 @@ +import chronos, stew/byteutils +import ../libp2p, + ../libp2p/protocols/relay/[relay, client] + +# Helper to create a circuit relay node +proc createCircuitRelaySwitch(r: Relay): Switch = + SwitchBuilder.new() + .withRng(newRng()) + .withAddresses(@[ MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() ]) + .withTcpTransport() + .withMplex() + .withNoise() + .withCircuitRelay(r) + .build() + +proc main() {.async.} = + # Create a custom protocol + let customProtoCodec = "/test" + var proto = new LPProtocol + proto.codec = customProtoCodec + proto.handler = proc(conn: Connection, proto: string) {.async.} = + var msg = string.fromBytes(await conn.readLp(1024)) + echo "1 - Dst Received: ", msg + assert "test1" == msg + await conn.writeLp("test2") + msg = string.fromBytes(await conn.readLp(1024)) + echo "2 - Dst Received: ", msg + assert "test3" == msg + await conn.writeLp("test4") + + let + relay = Relay.new() + clSrc = RelayClient.new() + clDst = RelayClient.new() + + # Create three hosts, enable relay client on two of them. + # The third one can relay connections for other peers. + # RelayClient can use a relay, Relay is a relay. + swRel = createCircuitRelaySwitch(relay) + swSrc = createCircuitRelaySwitch(clSrc) + swDst = createCircuitRelaySwitch(clDst) + + # Create a relay address to swDst using swRel as the relay + addrs = MultiAddress.init($swRel.peerInfo.addrs[0] & "/p2p/" & + $swRel.peerInfo.peerId & "/p2p-circuit/p2p/" & + $swDst.peerInfo.peerId).get() + + swDst.mount(proto) + + await swRel.start() + await swSrc.start() + await swDst.start() + + # Connect both Src and Dst to the relay, but not to each other. + await swSrc.connect(swRel.peerInfo.peerId, swRel.peerInfo.addrs) + await swDst.connect(swRel.peerInfo.peerId, swRel.peerInfo.addrs) + + # Dst reserve a slot on the relay. + let rsvp = await clDst.reserve(swRel.peerInfo.peerId, swRel.peerInfo.addrs) + + # Src dial Dst using the relay + let conn = await swSrc.dial(swDst.peerInfo.peerId, @[ addrs ], customProtoCodec) + + await conn.writeLp("test1") + var msg = string.fromBytes(await conn.readLp(1024)) + echo "1 - Src Received: ", msg + assert "test2" == msg + await conn.writeLp("test3") + msg = string.fromBytes(await conn.readLp(1024)) + echo "2 - Src Received: ", msg + assert "test4" == msg + + await relay.stop() + await allFutures(swSrc.stop(), swDst.stop(), swRel.stop()) + +waitFor(main()) diff --git a/libp2p.nimble b/libp2p.nimble index 5572d2527..26925ee2b 100644 --- a/libp2p.nimble +++ b/libp2p.nimble @@ -103,6 +103,7 @@ task test_slim, "Runs the (slimmed down) test suite": task examples_build, "Build the samples": buildSample("directchat") buildSample("helloworld", true) + buildSample("circuitrelay", true) buildTutorial("examples/tutorial_1_connect.md") buildTutorial("examples/tutorial_2_customproto.md") diff --git a/libp2p/builders.nim b/libp2p/builders.nim index 1b0bc3a4e..30d8a05db 100644 --- a/libp2p/builders.nim +++ b/libp2p/builders.nim @@ -23,7 +23,8 @@ import switch, peerid, peerinfo, stream/connection, multiaddress, crypto/crypto, transports/[transport, tcptransport], muxers/[muxer, mplex/mplex, yamux/yamux], - protocols/[identify, secure/secure, secure/noise, relay], + protocols/[identify, secure/secure, secure/noise], + protocols/relay/[relay, client, rtransport], connmanager, upgrademngrs/muxedupgrade, nameresolving/nameresolver, errors, utility @@ -54,8 +55,7 @@ type agentVersion: string nameResolver: NameResolver peerStoreCapacity: Option[int] - isCircuitRelay: bool - circuitRelayCanHop: bool + circuitRelay: Relay proc new*(T: type[SwitchBuilder]): T {.public.} = ## Creates a SwitchBuilder @@ -73,8 +73,7 @@ proc new*(T: type[SwitchBuilder]): T {.public.} = maxOut: -1, maxConnsPerPeer: MaxConnectionsPerPeer, protoVersion: ProtoVersion, - agentVersion: AgentVersion, - isCircuitRelay: false) + agentVersion: AgentVersion) proc withPrivateKey*(b: SwitchBuilder, privateKey: PrivateKey): SwitchBuilder {.public.} = ## Set the private key of the switch. Will be used to @@ -183,9 +182,8 @@ proc withNameResolver*(b: SwitchBuilder, nameResolver: NameResolver): SwitchBuil b.nameResolver = nameResolver b -proc withRelayTransport*(b: SwitchBuilder, canHop: bool): SwitchBuilder = - b.isCircuitRelay = true - b.circuitRelayCanHop = canHop +proc withCircuitRelay*(b: SwitchBuilder, r: Relay = Relay.new()): SwitchBuilder = + b.circuitRelay = r b proc build*(b: SwitchBuilder): Switch @@ -245,10 +243,11 @@ proc build*(b: SwitchBuilder): Switch nameResolver = b.nameResolver, peerStore = peerStore) - if b.isCircuitRelay: - let relay = Relay.new(switch, b.circuitRelayCanHop) - switch.mount(relay) - switch.addTransport(RelayTransport.new(relay, muxedUpgrade)) + if not isNil(b.circuitRelay): + if b.circuitRelay of RelayClient: + switch.addTransport(RelayTransport.new(RelayClient(b.circuitRelay), muxedUpgrade)) + b.circuitRelay.setup(switch) + switch.mount(b.circuitRelay) return switch diff --git a/libp2p/multistream.nim b/libp2p/multistream.nim index 42aef1530..c44f9564d 100644 --- a/libp2p/multistream.nim +++ b/libp2p/multistream.nim @@ -76,7 +76,7 @@ proc select*(m: MultistreamSelect, trace "reading first requested proto", conn if s == proto[0]: trace "successfully selected ", conn, proto = proto[0] - conn.tag = proto[0] + conn.protocol = proto[0] return proto[0] elif proto.len > 1: # Try to negotiate alternatives @@ -89,7 +89,7 @@ proc select*(m: MultistreamSelect, validateSuffix(s) if s == p: trace "selected protocol", conn, protocol = s - conn.tag = s + conn.protocol = s return s return "" else: @@ -167,7 +167,7 @@ proc handle*(m: MultistreamSelect, conn: Connection, active: bool = false) {.asy if (not isNil(h.match) and h.match(ms)) or h.protos.contains(ms): trace "found handler", conn, protocol = ms await conn.writeLp(ms & "\n") - conn.tag = ms + conn.protocol = ms await h.protocol.handler(conn, ms) return debug "no handlers", conn, protocol = ms diff --git a/libp2p/muxers/mplex/lpchannel.nim b/libp2p/muxers/mplex/lpchannel.nim index 758e1cb89..a15191501 100644 --- a/libp2p/muxers/mplex/lpchannel.nim +++ b/libp2p/muxers/mplex/lpchannel.nim @@ -168,8 +168,8 @@ method readOnce*(s: LPChannel, try: let bytes = await procCall BufferStream(s).readOnce(pbytes, nbytes) when defined(libp2p_network_protocols_metrics): - if s.tag.len > 0: - libp2p_protocols_bytes.inc(bytes.int64, labelValues=[s.tag, "in"]) + if s.protocol.len > 0: + libp2p_protocols_bytes.inc(bytes.int64, labelValues=[s.protocol, "in"]) trace "readOnce", s, bytes if bytes == 0: @@ -219,8 +219,8 @@ proc completeWrite( await fut when defined(libp2p_network_protocol_metrics): - if s.tag.len > 0: - libp2p_protocols_bytes.inc(msgLen.int64, labelValues=[s.tag, "out"]) + if s.protocol.len > 0: + libp2p_protocols_bytes.inc(msgLen.int64, labelValues=[s.protocol, "out"]) s.activity = true except CancelledError as exc: @@ -254,6 +254,8 @@ method write*(s: LPChannel, msg: seq[byte]): Future[void] = s.completeWrite(fut, msg.len) +method getWrapped*(s: LPChannel): Connection = s.conn + proc init*( L: type LPChannel, id: uint64, diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index 197b8c891..75ab36370 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -286,8 +286,8 @@ proc getOrCreatePeer*( p.peers.withValue(peerId, peer): return peer[] - proc getConn(): Future[Connection] = - p.switch.dial(peerId, protos) + proc getConn(): Future[Connection] {.async.} = + return await p.switch.dial(peerId, protos) proc dropConn(peer: PubSubPeer) = proc dropConnAsync(peer: PubSubPeer) {.async.} = diff --git a/libp2p/protocols/relay.nim b/libp2p/protocols/relay.nim deleted file mode 100644 index b6f2dc957..000000000 --- a/libp2p/protocols/relay.nim +++ /dev/null @@ -1,489 +0,0 @@ -# Nim-LibP2P -# Copyright (c) 2022 Status Research & Development GmbH -# Licensed under either of -# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) -# * MIT license ([LICENSE-MIT](LICENSE-MIT)) -# at your option. -# This file may not be copied, modified, or distributed except according to -# those terms. - -{.push raises: [Defect].} - -import options -import sequtils, strutils, tables -import chronos, chronicles - -import ../peerinfo, - ../switch, - ../multiaddress, - ../stream/connection, - ../protocols/protocol, - ../transports/transport, - ../utility, - ../errors - -const - RelayCodec* = "/libp2p/circuit/relay/0.1.0" - MsgSize* = 4096 - MaxCircuit* = 1024 - MaxCircuitPerPeer* = 64 - -logScope: - topics = "libp2p relay" - -type - RelayType* = enum - Hop = 1 - Stop = 2 - Status = 3 - CanHop = 4 - RelayStatus* = enum - Success = 100 - HopSrcAddrTooLong = 220 - HopDstAddrTooLong = 221 - HopSrcMultiaddrInvalid = 250 - HopDstMultiaddrInvalid = 251 - HopNoConnToDst = 260 - HopCantDialDst = 261 - HopCantOpenDstStream = 262 - HopCantSpeakRelay = 270 - HopCantRelayToSelf = 280 - StopSrcAddrTooLong = 320 - StopDstAddrTooLong = 321 - StopSrcMultiaddrInvalid = 350 - StopDstMultiaddrInvalid = 351 - StopRelayRefused = 390 - MalformedMessage = 400 - - RelayError* = object of LPError - - RelayPeer* = object - peerId*: PeerId - addrs*: seq[MultiAddress] - - AddConn* = proc(conn: Connection): Future[void] {.gcsafe, raises: [Defect].} - - RelayMessage* = object - msgType*: Option[RelayType] - srcPeer*: Option[RelayPeer] - dstPeer*: Option[RelayPeer] - status*: Option[RelayStatus] - - Relay* = ref object of LPProtocol - switch*: Switch - peerId: PeerId - dialer: Dial - canHop: bool - streamCount: int - hopCount: CountTable[PeerId] - - addConn: AddConn - - maxCircuit*: int - maxCircuitPerPeer*: int - msgSize*: int - -proc encodeMsg*(msg: RelayMessage): ProtoBuffer = - result = initProtoBuffer() - - if isSome(msg.msgType): - result.write(1, msg.msgType.get().ord.uint) - if isSome(msg.srcPeer): - var peer = initProtoBuffer() - peer.write(1, msg.srcPeer.get().peerId) - for ma in msg.srcPeer.get().addrs: - peer.write(2, ma.data.buffer) - peer.finish() - result.write(2, peer.buffer) - if isSome(msg.dstPeer): - var peer = initProtoBuffer() - peer.write(1, msg.dstPeer.get().peerId) - for ma in msg.dstPeer.get().addrs: - peer.write(2, ma.data.buffer) - peer.finish() - result.write(3, peer.buffer) - if isSome(msg.status): - result.write(4, msg.status.get().ord.uint) - - result.finish() - -proc decodeMsg*(buf: seq[byte]): Option[RelayMessage] = - var - rMsg: RelayMessage - msgTypeOrd: uint32 - src: RelayPeer - dst: RelayPeer - statusOrd: uint32 - pbSrc: ProtoBuffer - pbDst: ProtoBuffer - - let - pb = initProtoBuffer(buf) - r1 = pb.getField(1, msgTypeOrd) - r2 = pb.getField(2, pbSrc) - r3 = pb.getField(3, pbDst) - r4 = pb.getField(4, statusOrd) - - if r1.isErr() or r2.isErr() or r3.isErr() or r4.isErr(): - return none(RelayMessage) - - if r2.get() and - (pbSrc.getField(1, src.peerId).isErr() or - pbSrc.getRepeatedField(2, src.addrs).isErr()): - return none(RelayMessage) - - if r3.get() and - (pbDst.getField(1, dst.peerId).isErr() or - pbDst.getRepeatedField(2, dst.addrs).isErr()): - return none(RelayMessage) - - if r1.get(): rMsg.msgType = some(RelayType(msgTypeOrd)) - if r2.get(): rMsg.srcPeer = some(src) - if r3.get(): rMsg.dstPeer = some(dst) - if r4.get(): rMsg.status = some(RelayStatus(statusOrd)) - some(rMsg) - -proc sendStatus*(conn: Connection, code: RelayStatus) {.async, gcsafe.} = - trace "send status", status = $code & "(" & $ord(code) & ")" - let - msg = RelayMessage( - msgType: some(RelayType.Status), - status: some(code)) - pb = encodeMsg(msg) - - await conn.writeLp(pb.buffer) - -proc handleHopStream(r: Relay, conn: Connection, msg: RelayMessage) {.async, gcsafe.} = - r.streamCount.inc() - defer: - r.streamCount.dec() - - if r.streamCount > r.maxCircuit: - trace "refusing connection; too many active circuit" - await sendStatus(conn, RelayStatus.HopCantSpeakRelay) - return - - proc checkMsg(): Result[RelayMessage, RelayStatus] = - if not r.canHop: - return err(RelayStatus.HopCantSpeakRelay) - if msg.srcPeer.isNone: - return err(RelayStatus.HopSrcMultiaddrInvalid) - let src = msg.srcPeer.get() - if src.peerId != conn.peerId: - return err(RelayStatus.HopSrcMultiaddrInvalid) - if msg.dstPeer.isNone: - return err(RelayStatus.HopDstMultiaddrInvalid) - let dst = msg.dstPeer.get() - if dst.peerId == r.switch.peerInfo.peerId: - return err(RelayStatus.HopCantRelayToSelf) - if not r.switch.isConnected(dst.peerId): - trace "relay not connected to dst", dst - return err(RelayStatus.HopNoConnToDst) - ok(msg) - - let check = checkMsg() - if check.isErr: - await sendStatus(conn, check.error()) - return - let - src = msg.srcPeer.get() - dst = msg.dstPeer.get() - - # TODO: if r.acl # access control list - # and not r.acl.AllowHop(src.peerId, dst.peerId) - # sendStatus(conn, RelayStatus.HopCantSpeakRelay) - - r.hopCount.inc(src.peerId) - r.hopCount.inc(dst.peerId) - defer: - r.hopCount.inc(src.peerId, -1) - r.hopCount.inc(dst.peerId, -1) - - if r.hopCount[src.peerId] > r.maxCircuitPerPeer: - trace "refusing connection; too many connection from src", src, dst - await sendStatus(conn, RelayStatus.HopCantSpeakRelay) - return - - if r.hopCount[dst.peerId] > r.maxCircuitPerPeer: - trace "refusing connection; too many connection to dst", src, dst - await sendStatus(conn, RelayStatus.HopCantSpeakRelay) - return - - let connDst = try: - await r.switch.dial(dst.peerId, @[RelayCodec]) - except CancelledError as exc: - raise exc - except CatchableError as exc: - trace "error opening relay stream", dst, exc=exc.msg - await sendStatus(conn, RelayStatus.HopCantDialDst) - return - defer: - await connDst.close() - - let msgToSend = RelayMessage( - msgType: some(RelayType.Stop), - srcPeer: some(src), - dstPeer: some(dst), - status: none(RelayStatus)) - - let msgRcvFromDstOpt = try: - await connDst.writeLp(encodeMsg(msgToSend).buffer) - decodeMsg(await connDst.readLp(r.msgSize)) - except CancelledError as exc: - raise exc - except CatchableError as exc: - trace "error writing stop handshake or reading stop response", exc=exc.msg - await sendStatus(conn, RelayStatus.HopCantOpenDstStream) - return - - if msgRcvFromDstOpt.isNone: - trace "error reading stop response", msg = msgRcvFromDstOpt - await sendStatus(conn, RelayStatus.HopCantOpenDstStream) - return - - let msgRcvFromDst = msgRcvFromDstOpt.get() - if msgRcvFromDst.msgType.isNone or msgRcvFromDst.msgType.get() != RelayType.Status: - trace "unexcepted relay stop response", msgType = msgRcvFromDst.msgType - await sendStatus(conn, RelayStatus.HopCantOpenDstStream) - return - - if msgRcvFromDst.status.isNone or msgRcvFromDst.status.get() != RelayStatus.Success: - trace "relay stop failure", status=msgRcvFromDst.status - await sendStatus(conn, RelayStatus.HopCantOpenDstStream) - return - - await sendStatus(conn, RelayStatus.Success) - - trace "relaying connection", src, dst - - proc bridge(conn: Connection, connDst: Connection) {.async.} = - const bufferSize = 4096 - var - bufSrcToDst: array[bufferSize, byte] - bufDstToSrc: array[bufferSize, byte] - futSrc = conn.readOnce(addr bufSrcToDst[0], bufSrcToDst.high + 1) - futDst = connDst.readOnce(addr bufDstToSrc[0], bufDstToSrc.high + 1) - bytesSendFromSrcToDst = 0 - bytesSendFromDstToSrc = 0 - bufRead: int - - while not conn.closed() and not connDst.closed(): - try: - await futSrc or futDst - if futSrc.finished(): - bufRead = await futSrc - bytesSendFromSrcToDst.inc(bufRead) - await connDst.write(@bufSrcToDst[0..= 3: - result = CircuitRelay.match(sma[^2].get()) and - P2PPattern.match(sma[^1].get()) - trace "Handles return", ma, result - -proc new*(T: typedesc[RelayTransport], relay: Relay, upgrader: Upgrade): T = - result = T(relay: relay, upgrader: upgrader) - result.running = true - result.queue = newAsyncQueue[Connection](0) diff --git a/libp2p/protocols/relay/client.nim b/libp2p/protocols/relay/client.nim new file mode 100644 index 000000000..87811a89b --- /dev/null +++ b/libp2p/protocols/relay/client.nim @@ -0,0 +1,291 @@ +# Nim-LibP2P +# Copyright (c) 2022 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +{.push raises: [Defect].} + +import times, options + +import chronos, chronicles + +import ./relay, + ./messages, + ./rconn, + ./utils, + ../../peerinfo, + ../../switch, + ../../multiaddress, + ../../stream/connection + + +logScope: + topics = "libp2p relay relay-client" + +const RelayClientMsgSize = 4096 + +type + RelayClientError* = object of LPError + ReservationError* = object of RelayClientError + RelayV1DialError* = object of RelayClientError + RelayV2DialError* = object of RelayClientError + RelayClientAddConn* = proc(conn: Connection, + duration: uint32, + data: uint64): Future[void] {.gcsafe, raises: [Defect].} + RelayClient* = ref object of Relay + onNewConnection*: RelayClientAddConn + canHop: bool + + Rsvp* = object + expire*: uint64 # required, Unix expiration time (UTC) + addrs*: seq[MultiAddress] # relay address for reserving peer + voucher*: Option[Voucher] # optional, reservation voucher + limitDuration*: uint32 # seconds + limitData*: uint64 # bytes + +proc sendStopError(conn: Connection, code: StatusV2) {.async.} = + trace "send stop status", status = $code & " (" & $ord(code) & ")" + let msg = StopMessage(msgType: StopMessageType.Status, status: some(code)) + await conn.writeLp(encode(msg).buffer) + +proc handleRelayedConnect(cl: RelayClient, conn: Connection, msg: StopMessage) {.async.} = + if msg.peer.isNone(): + await sendStopError(conn, MalformedMessage) + return + let + # TODO: check the go version to see in which way this could fail + # it's unclear in the spec + src = msg.peer.get() + limitDuration = msg.limit.duration + limitData = msg.limit.data + msg = StopMessage( + msgType: StopMessageType.Status, + status: some(Ok)) + pb = encode(msg) + + trace "incoming relay connection", src + + if cl.onNewConnection == nil: + await sendStopError(conn, StatusV2.ConnectionFailed) + await conn.close() + return + await conn.writeLp(pb.buffer) + # This sound redundant but the callback could, in theory, be set to nil during + # conn.writeLp so it's safer to double check + if cl.onNewConnection != nil: await cl.onNewConnection(conn, limitDuration, limitData) + else: await conn.close() + +proc reserve*(cl: RelayClient, + peerId: PeerId, + addrs: seq[MultiAddress] = @[]): Future[Rsvp] {.async.} = + let conn = await cl.switch.dial(peerId, addrs, RelayV2HopCodec) + defer: await conn.close() + let + pb = encode(HopMessage(msgType: HopMessageType.Reserve)) + msg = try: + await conn.writeLp(pb.buffer) + HopMessage.decode(await conn.readLp(RelayClientMsgSize)).get() + except CancelledError as exc: + raise exc + except CatchableError as exc: + trace "error writing or reading reservation message", exc=exc.msg + raise newException(ReservationError, exc.msg) + + if msg.msgType != HopMessageType.Status: + raise newException(ReservationError, "Unexpected relay response type") + if msg.status.get(UnexpectedMessage) != Ok: + raise newException(ReservationError, "Reservation failed") + if msg.reservation.isNone(): + raise newException(ReservationError, "Missing reservation information") + + let reservation = msg.reservation.get() + if reservation.expire > int64.high().uint64 or + now().utc > reservation.expire.int64.fromUnix.utc: + raise newException(ReservationError, "Bad expiration date") + result.expire = reservation.expire + result.addrs = reservation.addrs + + if reservation.svoucher.isSome(): + let svoucher = SignedVoucher.decode(reservation.svoucher.get()) + if svoucher.isErr() or svoucher.get().data.relayPeerId != peerId: + raise newException(ReservationError, "Invalid voucher") + result.voucher = some(svoucher.get().data) + + result.limitDuration = msg.limit.duration + result.limitData = msg.limit.data + +proc dialPeerV1*( + cl: RelayClient, + conn: Connection, + dstPeerId: PeerId, + dstAddrs: seq[MultiAddress]): Future[Connection] {.async.} = + var + msg = RelayMessage( + msgType: some(RelayType.Hop), + srcPeer: some(RelayPeer(peerId: cl.switch.peerInfo.peerId, addrs: cl.switch.peerInfo.addrs)), + dstPeer: some(RelayPeer(peerId: dstPeerId, addrs: dstAddrs))) + pb = encode(msg) + + trace "Dial peer", msgSend=msg + + try: + await conn.writeLp(pb.buffer) + except CancelledError as exc: + raise exc + except CatchableError as exc: + trace "error writing hop request", exc=exc.msg + raise exc + + let msgRcvFromRelayOpt = try: + RelayMessage.decode(await conn.readLp(RelayClientMsgSize)) + except CancelledError as exc: + raise exc + except CatchableError as exc: + trace "error reading stop response", exc=exc.msg + await sendStatus(conn, StatusV1.HopCantOpenDstStream) + raise exc + + try: + if msgRcvFromRelayOpt.isNone: + raise newException(RelayV1DialError, "Hop can't open destination stream") + let msgRcvFromRelay = msgRcvFromRelayOpt.get() + if msgRcvFromRelay.msgType.isNone or msgRcvFromRelay.msgType.get() != RelayType.Status: + raise newException(RelayV1DialError, "Hop can't open destination stream: wrong message type") + if msgRcvFromRelay.status.isNone or msgRcvFromRelay.status.get() != StatusV1.Success: + raise newException(RelayV1DialError, "Hop can't open destination stream: status failed") + except RelayV1DialError as exc: + await sendStatus(conn, StatusV1.HopCantOpenDstStream) + raise exc + result = conn + +proc dialPeerV2*( + cl: RelayClient, + conn: RelayConnection, + dstPeerId: PeerId, + dstAddrs: seq[MultiAddress]): Future[Connection] {.async.} = + let + p = Peer(peerId: dstPeerId, addrs: dstAddrs) + pb = encode(HopMessage(msgType: HopMessageType.Connect, peer: some(p))) + + trace "Dial peer", p + + let msgRcvFromRelay = try: + await conn.writeLp(pb.buffer) + HopMessage.decode(await conn.readLp(RelayClientMsgSize)).get() + except CancelledError as exc: + raise exc + except CatchableError as exc: + trace "error reading stop response", exc=exc.msg + raise newException(RelayV2DialError, exc.msg) + + if msgRcvFromRelay.msgType != HopMessageType.Status: + raise newException(RelayV2DialError, "Unexpected stop response") + if msgRcvFromRelay.status.get(UnexpectedMessage) != Ok: + trace "Relay stop failed", msg = msgRcvFromRelay.status.get() + raise newException(RelayV2DialError, "Relay stop failure") + conn.limitDuration = msgRcvFromRelay.limit.duration + conn.limitData = msgRcvFromRelay.limit.data + return conn + +proc handleStopStreamV2(cl: RelayClient, conn: Connection) {.async, gcsafe.} = + let msgOpt = StopMessage.decode(await conn.readLp(RelayClientMsgSize)) + if msgOpt.isNone(): + await sendHopStatus(conn, MalformedMessage) + return + trace "client circuit relay v2 handle stream", msg = msgOpt.get() + let msg = msgOpt.get() + + if msg.msgType == StopMessageType.Connect: + await cl.handleRelayedConnect(conn, msg) + else: + trace "Unexpected client / relayv2 handshake", msgType=msg.msgType + await sendStopError(conn, MalformedMessage) + +proc handleStop(cl: RelayClient, conn: Connection, msg: RelayMessage) {.async, gcsafe.} = + if msg.srcPeer.isNone: + await sendStatus(conn, StatusV1.StopSrcMultiaddrInvalid) + return + let src = msg.srcPeer.get() + + if msg.dstPeer.isNone: + await sendStatus(conn, StatusV1.StopDstMultiaddrInvalid) + return + + let dst = msg.dstPeer.get() + if dst.peerId != cl.switch.peerInfo.peerId: + await sendStatus(conn, StatusV1.StopDstMultiaddrInvalid) + return + + trace "get a relay connection", src, conn + + if cl.onNewConnection == nil: + await sendStatus(conn, StatusV1.StopRelayRefused) + await conn.close() + return + await sendStatus(conn, StatusV1.Success) + # This sound redundant but the callback could, in theory, be set to nil during + # sendStatus(Success) so it's safer to double check + if cl.onNewConnection != nil: await cl.onNewConnection(conn, 0, 0) + else: await conn.close() + +proc handleStreamV1(cl: RelayClient, conn: Connection) {.async, gcsafe.} = + let msgOpt = RelayMessage.decode(await conn.readLp(RelayClientMsgSize)) + if msgOpt.isNone: + await sendStatus(conn, StatusV1.MalformedMessage) + return + trace "client circuit relay v1 handle stream", msg = msgOpt.get() + let msg = msgOpt.get() + case msg.msgType.get: + of RelayType.Hop: + if cl.canHop: await cl.handleHop(conn, msg) + else: await sendStatus(conn, StatusV1.HopCantSpeakRelay) + of RelayType.Stop: await cl.handleStop(conn, msg) + of RelayType.CanHop: + if cl.canHop: await sendStatus(conn, StatusV1.Success) + else: await sendStatus(conn, StatusV1.HopCantSpeakRelay) + else: + trace "Unexpected relay handshake", msgType=msg.msgType + await sendStatus(conn, StatusV1.MalformedMessage) + +proc new*(T: typedesc[RelayClient], canHop: bool = false, + reservationTTL: times.Duration = DefaultReservationTTL, + limitDuration: uint32 = DefaultLimitDuration, + limitData: uint64 = DefaultLimitData, + heartbeatSleepTime: uint32 = DefaultHeartbeatSleepTime, + maxCircuit: int = MaxCircuit, + maxCircuitPerPeer: int = MaxCircuitPerPeer, + msgSize: int = RelayClientMsgSize, + circuitRelayV1: bool = false): T = + + let cl = T(canHop: canHop, + reservationTTL: reservationTTL, + limit: Limit(duration: limitDuration, data: limitData), + heartbeatSleepTime: heartbeatSleepTime, + maxCircuit: maxCircuit, + maxCircuitPerPeer: maxCircuitPerPeer, + msgSize: msgSize, + isCircuitRelayV1: circuitRelayV1) + proc handleStream(conn: Connection, proto: string) {.async, gcsafe.} = + try: + case proto: + of RelayV1Codec: await cl.handleStreamV1(conn) + of RelayV2StopCodec: await cl.handleStopStreamV2(conn) + of RelayV2HopCodec: await cl.handleHopStreamV2(conn) + except CancelledError as exc: + raise exc + except CatchableError as exc: + trace "exception in client handler", exc = exc.msg, conn + finally: + trace "exiting client handler", conn + await conn.close() + + cl.handler = handleStream + cl.codecs = if cl.canHop: + @[RelayV1Codec, RelayV2HopCodec, RelayV2StopCodec] + else: + @[RelayV1Codec, RelayV2StopCodec] + cl diff --git a/libp2p/protocols/relay/messages.nim b/libp2p/protocols/relay/messages.nim new file mode 100644 index 000000000..af55c8874 --- /dev/null +++ b/libp2p/protocols/relay/messages.nim @@ -0,0 +1,367 @@ +# Nim-LibP2P +# Copyright (c) 2022 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +{.push raises: [Defect].} + +import options, macros, sequtils +import stew/objects +import ../../peerinfo, + ../../signed_envelope + +# Circuit Relay V1 Message + +type + RelayType* {.pure.} = enum + Hop = 1 + Stop = 2 + Status = 3 + CanHop = 4 + + StatusV1* {.pure.} = enum + Success = 100 + HopSrcAddrTooLong = 220 + HopDstAddrTooLong = 221 + HopSrcMultiaddrInvalid = 250 + HopDstMultiaddrInvalid = 251 + HopNoConnToDst = 260 + HopCantDialDst = 261 + HopCantOpenDstStream = 262 + HopCantSpeakRelay = 270 + HopCantRelayToSelf = 280 + StopSrcAddrTooLong = 320 + StopDstAddrTooLong = 321 + StopSrcMultiaddrInvalid = 350 + StopDstMultiaddrInvalid = 351 + StopRelayRefused = 390 + MalformedMessage = 400 + + RelayPeer* = object + peerId*: PeerId + addrs*: seq[MultiAddress] + + RelayMessage* = object + msgType*: Option[RelayType] + srcPeer*: Option[RelayPeer] + dstPeer*: Option[RelayPeer] + status*: Option[StatusV1] + +proc encode*(msg: RelayMessage): ProtoBuffer = + result = initProtoBuffer() + + if isSome(msg.msgType): + result.write(1, msg.msgType.get().ord.uint) + if isSome(msg.srcPeer): + var peer = initProtoBuffer() + peer.write(1, msg.srcPeer.get().peerId) + for ma in msg.srcPeer.get().addrs: + peer.write(2, ma.data.buffer) + peer.finish() + result.write(2, peer.buffer) + if isSome(msg.dstPeer): + var peer = initProtoBuffer() + peer.write(1, msg.dstPeer.get().peerId) + for ma in msg.dstPeer.get().addrs: + peer.write(2, ma.data.buffer) + peer.finish() + result.write(3, peer.buffer) + if isSome(msg.status): + result.write(4, msg.status.get().ord.uint) + + result.finish() + +proc decode*(_: typedesc[RelayMessage], buf: seq[byte]): Option[RelayMessage] = + var + rMsg: RelayMessage + msgTypeOrd: uint32 + src: RelayPeer + dst: RelayPeer + statusOrd: uint32 + pbSrc: ProtoBuffer + pbDst: ProtoBuffer + + let + pb = initProtoBuffer(buf) + r1 = pb.getField(1, msgTypeOrd) + r2 = pb.getField(2, pbSrc) + r3 = pb.getField(3, pbDst) + r4 = pb.getField(4, statusOrd) + + if r1.isErr() or r2.isErr() or r3.isErr() or r4.isErr(): + return none(RelayMessage) + + if r2.get() and + (pbSrc.getField(1, src.peerId).isErr() or + pbSrc.getRepeatedField(2, src.addrs).isErr()): + return none(RelayMessage) + + if r3.get() and + (pbDst.getField(1, dst.peerId).isErr() or + pbDst.getRepeatedField(2, dst.addrs).isErr()): + return none(RelayMessage) + + if r1.get(): + if msgTypeOrd.int notin RelayType: + return none(RelayMessage) + rMsg.msgType = some(RelayType(msgTypeOrd)) + if r2.get(): rMsg.srcPeer = some(src) + if r3.get(): rMsg.dstPeer = some(dst) + if r4.get(): + if statusOrd.int notin StatusV1: + return none(RelayMessage) + rMsg.status = some(StatusV1(statusOrd)) + some(rMsg) + +# Voucher + +type + Voucher* = object + relayPeerId*: PeerId # peer ID of the relay + reservingPeerId*: PeerId # peer ID of the reserving peer + expiration*: uint64 # UNIX UTC expiration time for the reservation + +proc decode*(_: typedesc[Voucher], buf: seq[byte]): Result[Voucher, ProtoError] = + let pb = initProtoBuffer(buf) + var v = Voucher() + + ? pb.getRequiredField(1, v.relayPeerId) + ? pb.getRequiredField(2, v.reservingPeerId) + ? pb.getRequiredField(3, v.expiration) + + ok(v) + +proc encode*(v: Voucher): seq[byte] = + var pb = initProtoBuffer() + + pb.write(1, v.relayPeerId) + pb.write(2, v.reservingPeerId) + pb.write(3, v.expiration) + + pb.finish() + pb.buffer + +proc init*(T: typedesc[Voucher], + relayPeerId: PeerId, + reservingPeerId: PeerId, + expiration: uint64): T = + T( + relayPeerId = relayPeerId, + reservingPeerId = reservingPeerId, + expiration: expiration + ) + +type SignedVoucher* = SignedPayload[Voucher] + +proc payloadDomain*(_: typedesc[Voucher]): string = "libp2p-relay-rsvp" +proc payloadType*(_: typedesc[Voucher]): seq[byte] = @[ (byte)0x03, (byte)0x02 ] + +proc checkValid*(spr: SignedVoucher): Result[void, EnvelopeError] = + if not spr.data.relayPeerId.match(spr.envelope.publicKey): + err(EnvelopeInvalidSignature) + else: + ok() + +# Circuit Relay V2 Hop Message + +type + Peer* = object + peerId*: PeerId + addrs*: seq[MultiAddress] + Reservation* = object + expire*: uint64 # required, Unix expiration time (UTC) + addrs*: seq[MultiAddress] # relay address for reserving peer + svoucher*: Option[seq[byte]] # optional, reservation voucher + Limit* = object + duration*: uint32 # seconds + data*: uint64 # bytes + + StatusV2* = enum + Ok = 100 + ReservationRefused = 200 + ResourceLimitExceeded = 201 + PermissionDenied = 202 + ConnectionFailed = 203 + NoReservation = 204 + MalformedMessage = 400 + UnexpectedMessage = 401 + HopMessageType* {.pure.} = enum + Reserve = 0 + Connect = 1 + Status = 2 + HopMessage* = object + msgType*: HopMessageType + peer*: Option[Peer] + reservation*: Option[Reservation] + limit*: Limit + status*: Option[StatusV2] + +proc encode*(msg: HopMessage): ProtoBuffer = + var pb = initProtoBuffer() + + pb.write(1, msg.msgType.ord.uint) + if msg.peer.isSome(): + var ppb = initProtoBuffer() + ppb.write(1, msg.peer.get().peerId) + for ma in msg.peer.get().addrs: + ppb.write(2, ma.data.buffer) + ppb.finish() + pb.write(2, ppb.buffer) + if msg.reservation.isSome(): + let rsrv = msg.reservation.get() + var rpb = initProtoBuffer() + rpb.write(1, rsrv.expire) + for ma in rsrv.addrs: + rpb.write(2, ma.data.buffer) + if rsrv.svoucher.isSome(): + rpb.write(3, rsrv.svoucher.get()) + rpb.finish() + pb.write(3, rpb.buffer) + if msg.limit.duration > 0 or msg.limit.data > 0: + var lpb = initProtoBuffer() + if msg.limit.duration > 0: lpb.write(1, msg.limit.duration) + if msg.limit.data > 0: lpb.write(2, msg.limit.data) + lpb.finish() + pb.write(4, lpb.buffer) + if msg.status.isSome(): + pb.write(5, msg.status.get().ord.uint) + + pb.finish() + pb + +proc decode*(_: typedesc[HopMessage], buf: seq[byte]): Option[HopMessage] = + var + msg: HopMessage + msgTypeOrd: uint32 + pbPeer: ProtoBuffer + pbReservation: ProtoBuffer + pbLimit: ProtoBuffer + statusOrd: uint32 + peer: Peer + reservation: Reservation + limit: Limit + res: bool + + let + pb = initProtoBuffer(buf) + r1 = pb.getRequiredField(1, msgTypeOrd) + r2 = pb.getField(2, pbPeer) + r3 = pb.getField(3, pbReservation) + r4 = pb.getField(4, pbLimit) + r5 = pb.getField(5, statusOrd) + + if r1.isErr() or r2.isErr() or r3.isErr() or r4.isErr() or r5.isErr(): + return none(HopMessage) + + if r2.get() and + (pbPeer.getRequiredField(1, peer.peerId).isErr() or + pbPeer.getRepeatedField(2, peer.addrs).isErr()): + return none(HopMessage) + + if r3.get(): + var svoucher: seq[byte] + let rSVoucher = pbReservation.getField(3, svoucher) + if pbReservation.getRequiredField(1, reservation.expire).isErr() or + pbReservation.getRepeatedField(2, reservation.addrs).isErr() or + rSVoucher.isErr(): + return none(HopMessage) + if rSVoucher.get(): reservation.svoucher = some(svoucher) + + if r4.get() and + (pbLimit.getField(1, limit.duration).isErr() or + pbLimit.getField(2, limit.data).isErr()): + return none(HopMessage) + + if not checkedEnumAssign(msg.msgType, msgTypeOrd): + return none(HopMessage) + if r2.get(): msg.peer = some(peer) + if r3.get(): msg.reservation = some(reservation) + if r4.get(): msg.limit = limit + if r5.get(): + if statusOrd.int notin StatusV2: + return none(HopMessage) + msg.status = some(StatusV2(statusOrd)) + some(msg) + +# Circuit Relay V2 Stop Message + +type + StopMessageType* {.pure.} = enum + Connect = 0 + Status = 1 + StopMessage* = object + msgType*: StopMessageType + peer*: Option[Peer] + limit*: Limit + status*: Option[StatusV2] + + +proc encode*(msg: StopMessage): ProtoBuffer = + var pb = initProtoBuffer() + + pb.write(1, msg.msgType.ord.uint) + if msg.peer.isSome(): + var ppb = initProtoBuffer() + ppb.write(1, msg.peer.get().peerId) + for ma in msg.peer.get().addrs: + ppb.write(2, ma.data.buffer) + ppb.finish() + pb.write(2, ppb.buffer) + if msg.limit.duration > 0 or msg.limit.data > 0: + var lpb = initProtoBuffer() + if msg.limit.duration > 0: lpb.write(1, msg.limit.duration) + if msg.limit.data > 0: lpb.write(2, msg.limit.data) + lpb.finish() + pb.write(3, lpb.buffer) + if msg.status.isSome(): + pb.write(4, msg.status.get().ord.uint) + + pb.finish() + pb + +proc decode*(_: typedesc[StopMessage], buf: seq[byte]): Option[StopMessage] = + var + msg: StopMessage + msgTypeOrd: uint32 + pbPeer: ProtoBuffer + pbLimit: ProtoBuffer + statusOrd: uint32 + peer: Peer + limit: Limit + rVoucher: ProtoResult[bool] + res: bool + + let + pb = initProtoBuffer(buf) + r1 = pb.getRequiredField(1, msgTypeOrd) + r2 = pb.getField(2, pbPeer) + r3 = pb.getField(3, pbLimit) + r4 = pb.getField(4, statusOrd) + + if r1.isErr() or r2.isErr() or r3.isErr() or r4.isErr(): + return none(StopMessage) + + if r2.get() and + (pbPeer.getRequiredField(1, peer.peerId).isErr() or + pbPeer.getRepeatedField(2, peer.addrs).isErr()): + return none(StopMessage) + + if r3.get() and + (pbLimit.getField(1, limit.duration).isErr() or + pbLimit.getField(2, limit.data).isErr()): + return none(StopMessage) + + if msgTypeOrd.int notin StopMessageType.low.ord .. StopMessageType.high.ord: + return none(StopMessage) + msg.msgType = StopMessageType(msgTypeOrd) + if r2.get(): msg.peer = some(peer) + if r3.get(): msg.limit = limit + if r4.get(): + if statusOrd.int notin StatusV2: + return none(StopMessage) + msg.status = some(StatusV2(statusOrd)) + some(msg) diff --git a/libp2p/protocols/relay/rconn.nim b/libp2p/protocols/relay/rconn.nim new file mode 100644 index 000000000..c407cd535 --- /dev/null +++ b/libp2p/protocols/relay/rconn.nim @@ -0,0 +1,58 @@ +# Nim-LibP2P +# Copyright (c) 2022 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +{.push raises: [Defect].} + +import chronos + +import ../../stream/connection + +type + RelayConnection* = ref object of Connection + conn*: Connection + limitDuration*: uint32 + limitData*: uint64 + dataSent*: uint64 + +method readOnce*( + self: RelayConnection, + pbytes: pointer, + nbytes: int): Future[int] {.async.} = + self.activity = true + return await self.conn.readOnce(pbytes, nbytes) + +method write*(self: RelayConnection, msg: seq[byte]): Future[void] {.async.} = + self.dataSent.inc(msg.len) + if self.limitData != 0 and self.dataSent > self.limitData: + await self.close() + return + self.activity = true + await self.conn.write(msg) + +method closeImpl*(self: RelayConnection): Future[void] {.async.} = + await self.conn.close() + await procCall Connection(self).closeImpl() + +method getWrapped*(self: RelayConnection): Connection = self.conn + +proc new*( + T: typedesc[RelayConnection], + conn: Connection, + limitDuration: uint32, + limitData: uint64): T = + let rc = T(conn: conn, limitDuration: limitDuration, limitData: limitData) + rc.initStream() + if limitDuration > 0: + proc checkDurationConnection() {.async.} = + let sleep = sleepAsync(limitDuration.seconds()) + await sleep or conn.join() + if sleep.finished: await conn.close() + else: sleep.cancel() + asyncSpawn checkDurationConnection() + return rc diff --git a/libp2p/protocols/relay/relay.nim b/libp2p/protocols/relay/relay.nim new file mode 100644 index 000000000..394e4bde8 --- /dev/null +++ b/libp2p/protocols/relay/relay.nim @@ -0,0 +1,383 @@ +# Nim-LibP2P +# Copyright (c) 2022 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +{.push raises: [Defect].} + +import options, sequtils, tables, sugar + +import chronos, chronicles + +import ./messages, + ./rconn, + ./utils, + ../../peerinfo, + ../../switch, + ../../multiaddress, + ../../multicodec, + ../../stream/connection, + ../../protocols/protocol, + ../../transports/transport, + ../../errors, + ../../utils/heartbeat, + ../../signed_envelope + +# TODO: +# * Eventually replace std/times by chronos/timer. Currently chronos/timer +# doesn't offer the possibility to get a datetime in UNIX UTC +# * Eventually add an access control list in the handleReserve, handleConnect +# and handleHop +# * Better reservation management ie find a way to re-reserve when the end is nigh + +import std/times +export chronicles + +const + RelayMsgSize* = 4096 + DefaultReservationTTL* = initDuration(hours = 1) + DefaultLimitDuration* = 120 + DefaultLimitData* = 1 shl 17 + DefaultHeartbeatSleepTime* = 1 + MaxCircuit* = 128 + MaxCircuitPerPeer* = 16 + +logScope: + topics = "libp2p relay" + +type + RelayV2Error* = object of LPError + SendStopError = object of RelayV2Error + +# Relay Side + +type + Relay* = ref object of LPProtocol + switch*: Switch + peerCount: CountTable[PeerId] + + # number of reservation (relayv2) + number of connection (relayv1) + maxCircuit*: int + + maxCircuitPerPeer*: int + msgSize*: int + # RelayV1 + isCircuitRelayV1*: bool + streamCount: int + # RelayV2 + rsvp: Table[PeerId, DateTime] + reservationLoop: Future[void] + reservationTTL*: times.Duration + heartbeatSleepTime*: uint32 + limit*: Limit + +# Relay V2 + +proc createReserveResponse( + r: Relay, + pid: PeerId, + expire: DateTime): Result[HopMessage, CryptoError] = + let + expireUnix = expire.toTime.toUnix.uint64 + v = Voucher(relayPeerId: r.switch.peerInfo.peerId, + reservingPeerId: pid, + expiration: expireUnix) + sv = ? SignedVoucher.init(r.switch.peerInfo.privateKey, v) + ma = ? MultiAddress.init("/p2p/" & $r.switch.peerInfo.peerId).orErr(CryptoError.KeyError) + rsrv = Reservation(expire: expireUnix, + addrs: r.switch.peerInfo.addrs.mapIt( + ? it.concat(ma).orErr(CryptoError.KeyError)), + svoucher: some(? sv.encode)) + msg = HopMessage(msgType: HopMessageType.Status, + reservation: some(rsrv), + limit: r.limit, + status: some(Ok)) + return ok(msg) + +proc isRelayed(conn: Connection): bool = + var wrappedConn = conn + while not isNil(wrappedConn): + if wrappedConn of RelayConnection: + return true + wrappedConn = wrappedConn.getWrapped() + return false + +proc handleReserve(r: Relay, conn: Connection) {.async, gcsafe.} = + if conn.isRelayed(): + trace "reservation attempt over relay connection", pid = conn.peerId + await sendHopStatus(conn, PermissionDenied) + return + + if r.peerCount[conn.peerId] + r.rsvp.len() >= r.maxCircuit: + trace "Too many reservations", pid = conn.peerId + await sendHopStatus(conn, ReservationRefused) + return + let + pid = conn.peerId + expire = now().utc + r.reservationTTL + msg = r.createReserveResponse(pid, expire) + + trace "reserving relay slot for", pid + if msg.isErr(): + trace "error signing the voucher", error = error(msg), pid + return + r.rsvp[pid] = expire + await conn.writeLp(encode(msg.get()).buffer) + +proc handleConnect(r: Relay, + connSrc: Connection, + msg: HopMessage) {.async, gcsafe.} = + if connSrc.isRelayed(): + trace "connection attempt over relay connection" + await sendHopStatus(connSrc, PermissionDenied) + return + if msg.peer.isNone(): + await sendHopStatus(connSrc, MalformedMessage) + return + + let + src = connSrc.peerId + dst = msg.peer.get().peerId + if dst notin r.rsvp: + trace "refusing connection, no reservation", src, dst + await sendHopStatus(connSrc, NoReservation) + return + + r.peerCount.inc(src) + r.peerCount.inc(dst) + defer: + r.peerCount.inc(src, -1) + r.peerCount.inc(dst, -1) + + if r.peerCount[src] > r.maxCircuitPerPeer or + r.peerCount[dst] > r.maxCircuitPerPeer: + trace "too many connections", src = r.peerCount[src], + dst = r.peerCount[dst], + max = r.maxCircuitPerPeer + await sendHopStatus(connSrc, ResourceLimitExceeded) + return + + let connDst = try: + await r.switch.dial(dst, RelayV2StopCodec) + except CancelledError as exc: + raise exc + except CatchableError as exc: + trace "error opening relay stream", dst, exc=exc.msg + await sendHopStatus(connSrc, ConnectionFailed) + return + defer: + await connDst.close() + + proc sendStopMsg() {.async.} = + let stopMsg = StopMessage(msgType: StopMessageType.Connect, + peer: some(Peer(peerId: src, addrs: @[])), + limit: r.limit) + await connDst.writeLp(encode(stopMsg).buffer) + let msg = StopMessage.decode(await connDst.readLp(r.msgSize)).get() + if msg.msgType != StopMessageType.Status: + raise newException(SendStopError, "Unexpected stop response, not a status message") + if msg.status.get(UnexpectedMessage) != Ok: + raise newException(SendStopError, "Relay stop failure") + await connSrc.writeLp(encode(HopMessage(msgType: HopMessageType.Status, + status: some(Ok))).buffer) + try: + await sendStopMsg() + except CancelledError as exc: + raise exc + except CatchableError as exc: + trace "error sending stop message", msg = exc.msg + await sendHopStatus(connSrc, ConnectionFailed) + return + + trace "relaying connection", src, dst + let + rconnSrc = RelayConnection.new(connSrc, r.limit.duration, r.limit.data) + rconnDst = RelayConnection.new(connDst, r.limit.duration, r.limit.data) + defer: + await rconnSrc.close() + await rconnDst.close() + await bridge(rconnSrc, rconnDst) + +proc handleHopStreamV2*(r: Relay, conn: Connection) {.async, gcsafe.} = + let msgOpt = HopMessage.decode(await conn.readLp(r.msgSize)) + if msgOpt.isNone(): + await sendHopStatus(conn, MalformedMessage) + return + trace "relayv2 handle stream", msg = msgOpt.get() + let msg = msgOpt.get() + case msg.msgType: + of HopMessageType.Reserve: await r.handleReserve(conn) + of HopMessageType.Connect: await r.handleConnect(conn, msg) + else: + trace "Unexpected relayv2 handshake", msgType=msg.msgType + await sendHopStatus(conn, MalformedMessage) + +# Relay V1 + +proc handleHop*(r: Relay, connSrc: Connection, msg: RelayMessage) {.async, gcsafe.} = + r.streamCount.inc() + defer: r.streamCount.dec() + if r.streamCount + r.rsvp.len() >= r.maxCircuit: + trace "refusing connection; too many active circuit", streamCount = r.streamCount, rsvp = r.rsvp.len() + await sendStatus(connSrc, StatusV1.HopCantSpeakRelay) + return + + proc checkMsg(): Result[RelayMessage, StatusV1] = + if msg.srcPeer.isNone: + return err(StatusV1.HopSrcMultiaddrInvalid) + let src = msg.srcPeer.get() + if src.peerId != connSrc.peerId: + return err(StatusV1.HopSrcMultiaddrInvalid) + if msg.dstPeer.isNone: + return err(StatusV1.HopDstMultiaddrInvalid) + let dst = msg.dstPeer.get() + if dst.peerId == r.switch.peerInfo.peerId: + return err(StatusV1.HopCantRelayToSelf) + if not r.switch.isConnected(dst.peerId): + trace "relay not connected to dst", dst + return err(StatusV1.HopNoConnToDst) + ok(msg) + let check = checkMsg() + if check.isErr: + await sendStatus(connSrc, check.error()) + return + + let + src = msg.srcPeer.get() + dst = msg.dstPeer.get() + if r.peerCount[src.peerId] >= r.maxCircuitPerPeer or + r.peerCount[dst.peerId] >= r.maxCircuitPerPeer: + trace "refusing connection; too many connection from src or to dst", src, dst + await sendStatus(connSrc, StatusV1.HopCantSpeakRelay) + return + r.peerCount.inc(src.peerId) + r.peerCount.inc(dst.peerId) + defer: + r.peerCount.inc(src.peerId, -1) + r.peerCount.inc(dst.peerId, -1) + + let connDst = try: + await r.switch.dial(dst.peerId, RelayV1Codec) + except CancelledError as exc: + raise exc + except CatchableError as exc: + trace "error opening relay stream", dst, exc=exc.msg + await sendStatus(connSrc, StatusV1.HopCantDialDst) + return + defer: + await connDst.close() + + let msgToSend = RelayMessage( + msgType: some(RelayType.Stop), + srcPeer: some(src), + dstPeer: some(dst)) + + let msgRcvFromDstOpt = try: + await connDst.writeLp(encode(msgToSend).buffer) + RelayMessage.decode(await connDst.readLp(r.msgSize)) + except CancelledError as exc: + raise exc + except CatchableError as exc: + trace "error writing stop handshake or reading stop response", exc=exc.msg + await sendStatus(connSrc, StatusV1.HopCantOpenDstStream) + return + + if msgRcvFromDstOpt.isNone: + trace "error reading stop response", msg = msgRcvFromDstOpt + await sendStatus(connSrc, StatusV1.HopCantOpenDstStream) + return + + let msgRcvFromDst = msgRcvFromDstOpt.get() + if msgRcvFromDst.msgType.get(RelayType.Stop) != RelayType.Status or + msgRcvFromDst.status.get(StatusV1.StopRelayRefused) != StatusV1.Success: + trace "unexcepted relay stop response", msgRcvFromDst + await sendStatus(connSrc, StatusV1.HopCantOpenDstStream) + return + + await sendStatus(connSrc, StatusV1.Success) + trace "relaying connection", src, dst + await bridge(connSrc, connDst) + +proc handleStreamV1(r: Relay, conn: Connection) {.async, gcsafe.} = + let msgOpt = RelayMessage.decode(await conn.readLp(r.msgSize)) + if msgOpt.isNone: + await sendStatus(conn, StatusV1.MalformedMessage) + return + trace "relay handle stream", msg = msgOpt.get() + let msg = msgOpt.get() + case msg.msgType.get: + of RelayType.Hop: await r.handleHop(conn, msg) + of RelayType.Stop: await sendStatus(conn, StatusV1.StopRelayRefused) + of RelayType.CanHop: await sendStatus(conn, StatusV1.Success) + else: + trace "Unexpected relay handshake", msgType=msg.msgType + await sendStatus(conn, StatusV1.MalformedMessage) + +proc setup*(r: Relay, switch: Switch) = + r.switch = switch + r.switch.addPeerEventHandler( + proc (peerId: PeerId, event: PeerEvent) {.async.} = + r.rsvp.del(peerId), + Left) + +proc new*(T: typedesc[Relay], + reservationTTL: times.Duration = DefaultReservationTTL, + limitDuration: uint32 = DefaultLimitDuration, + limitData: uint64 = DefaultLimitData, + heartbeatSleepTime: uint32 = DefaultHeartbeatSleepTime, + maxCircuit: int = MaxCircuit, + maxCircuitPerPeer: int = MaxCircuitPerPeer, + msgSize: int = RelayMsgSize, + circuitRelayV1: bool = false): T = + + let r = T(reservationTTL: reservationTTL, + limit: Limit(duration: limitDuration, data: limitData), + heartbeatSleepTime: heartbeatSleepTime, + maxCircuit: maxCircuit, + maxCircuitPerPeer: maxCircuitPerPeer, + msgSize: msgSize, + isCircuitRelayV1: circuitRelayV1) + + proc handleStream(conn: Connection, proto: string) {.async, gcsafe.} = + try: + case proto: + of RelayV2HopCodec: await r.handleHopStreamV2(conn) + of RelayV1Codec: await r.handleStreamV1(conn) + except CancelledError as exc: + raise exc + except CatchableError as exc: + debug "exception in relayv2 handler", exc = exc.msg, conn + finally: + trace "exiting relayv2 handler", conn + await conn.close() + + r.codecs = if r.isCircuitRelayV1: @[RelayV1Codec] + else: @[RelayV2HopCodec, RelayV1Codec] + r.handler = handleStream + r + +proc deletesReservation(r: Relay) {.async.} = + heartbeat "Reservation timeout", r.heartbeatSleepTime.seconds(): + let n = now().utc + for k in toSeq(r.rsvp.keys): + if n > r.rsvp[k]: + r.rsvp.del(k) + +method start*(r: Relay) {.async.} = + if not r.reservationLoop.isNil: + warn "Starting relay twice" + return + r.reservationLoop = r.deletesReservation() + r.started = true + +method stop*(r: Relay) {.async.} = + if r.reservationLoop.isNil: + warn "Stopping relay without starting it" + return + r.started = false + r.reservationLoop.cancel() + r.reservationLoop = nil diff --git a/libp2p/protocols/relay/rtransport.nim b/libp2p/protocols/relay/rtransport.nim new file mode 100644 index 000000000..15edd4f7e --- /dev/null +++ b/libp2p/protocols/relay/rtransport.nim @@ -0,0 +1,106 @@ +# Nim-LibP2P +# Copyright (c) 2022 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +{.push raises: [Defect].} + +import sequtils, strutils + +import chronos, chronicles + +import ./client, + ./rconn, + ./utils, + ../../switch, + ../../stream/connection, + ../../transports/transport + +logScope: + topics = "libp2p relay relay-transport" + +type + RelayTransport* = ref object of Transport + client*: RelayClient + queue: AsyncQueue[Connection] + selfRunning: bool + +method start*(self: RelayTransport, ma: seq[MultiAddress]) {.async.} = + if self.selfRunning: + trace "Relay transport already running" + return + + self.client.onNewConnection = proc( + conn: Connection, + duration: uint32 = 0, + data: uint64 = 0) {.async, gcsafe, raises: [Defect].} = + await self.queue.addLast(RelayConnection.new(conn, duration, data)) + await conn.join() + self.selfRunning = true + await procCall Transport(self).start(ma) + trace "Starting Relay transport" + +method stop*(self: RelayTransport) {.async, gcsafe.} = + self.running = false + self.selfRunning = false + self.client.onNewConnection = nil + while not self.queue.empty(): + await self.queue.popFirstNoWait().close() + +method accept*(self: RelayTransport): Future[Connection] {.async, gcsafe.} = + result = await self.queue.popFirst() + +proc dial*(self: RelayTransport, ma: MultiAddress): Future[Connection] {.async, gcsafe.} = + let + sma = toSeq(ma.items()) + relayAddrs = sma[0..sma.len-4].mapIt(it.tryGet()).foldl(a & b) + var + relayPeerId: PeerId + dstPeerId: PeerId + if not relayPeerId.init(($(sma[^3].get())).split('/')[2]): + raise newException(RelayV2DialError, "Relay doesn't exist") + if not dstPeerId.init(($(sma[^1].get())).split('/')[2]): + raise newException(RelayV2DialError, "Destination doesn't exist") + trace "Dial", relayPeerId, dstPeerId + + let conn = await self.client.switch.dial( + relayPeerId, + @[ relayAddrs ], + @[ RelayV2HopCodec, RelayV1Codec ]) + conn.dir = Direction.Out + var rc: RelayConnection + try: + case conn.protocol: + of RelayV1Codec: + return await self.client.dialPeerV1(conn, dstPeerId, @[]) + of RelayV2HopCodec: + rc = RelayConnection.new(conn, 0, 0) + return await self.client.dialPeerV2(rc, dstPeerId, @[]) + except CancelledError as exc: + raise exc + except CatchableError as exc: + if not rc.isNil: await rc.close() + raise exc + +method dial*( + self: RelayTransport, + hostname: string, + address: MultiAddress): Future[Connection] {.async, gcsafe.} = + result = await self.dial(address) + +method handles*(self: RelayTransport, ma: MultiAddress): bool {.gcsafe} = + if ma.protocols.isOk(): + let sma = toSeq(ma.items()) + if sma.len >= 3: + result = CircuitRelay.match(sma[^2].get()) and + P2PPattern.match(sma[^1].get()) + trace "Handles return", ma, result + +proc new*(T: typedesc[RelayTransport], cl: RelayClient, upgrader: Upgrade): T = + result = T(client: cl, upgrader: upgrader) + result.running = true + result.queue = newAsyncQueue[Connection](0) diff --git a/libp2p/protocols/relay/utils.nim b/libp2p/protocols/relay/utils.nim new file mode 100644 index 000000000..246676203 --- /dev/null +++ b/libp2p/protocols/relay/utils.nim @@ -0,0 +1,84 @@ +# Nim-LibP2P +# Copyright (c) 2022 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +{.push raises: [Defect].} + +import options + +import chronos, chronicles + +import ./messages, + ../../stream/connection + +logScope: + topics = "libp2p relay relay-utils" + +const + RelayV1Codec* = "/libp2p/circuit/relay/0.1.0" + RelayV2HopCodec* = "/libp2p/circuit/relay/0.2.0/hop" + RelayV2StopCodec* = "/libp2p/circuit/relay/0.2.0/stop" + +proc sendStatus*(conn: Connection, code: StatusV1) {.async, gcsafe.} = + trace "send relay/v1 status", status = $code & "(" & $ord(code) & ")" + let + msg = RelayMessage(msgType: some(RelayType.Status), status: some(code)) + pb = encode(msg) + await conn.writeLp(pb.buffer) + +proc sendHopStatus*(conn: Connection, code: StatusV2) {.async, gcsafe.} = + trace "send hop relay/v2 status", status = $code & "(" & $ord(code) & ")" + let + msg = HopMessage(msgType: HopMessageType.Status, status: some(code)) + pb = encode(msg) + await conn.writeLp(pb.buffer) + +proc sendStopStatus*(conn: Connection, code: StatusV2) {.async.} = + trace "send stop relay/v2 status", status = $code & " (" & $ord(code) & ")" + let + msg = StopMessage(msgType: StopMessageType.Status, status: some(code)) + pb = encode(msg) + await conn.writeLp(pb.buffer) + +proc bridge*(connSrc: Connection, connDst: Connection) {.async.} = + const bufferSize = 4096 + var + bufSrcToDst: array[bufferSize, byte] + bufDstToSrc: array[bufferSize, byte] + futSrc = connSrc.readOnce(addr bufSrcToDst[0], bufSrcToDst.high + 1) + futDst = connDst.readOnce(addr bufDstToSrc[0], bufDstToSrc.high + 1) + bytesSendFromSrcToDst = 0 + bytesSendFromDstToSrc = 0 + bufRead: int + + try: + while not connSrc.closed() and not connDst.closed(): + await futSrc or futDst + if futSrc.finished(): + bufRead = await futSrc + bytesSendFromSrcToDst.inc(bufRead) + await connDst.write(@bufSrcToDst[0.. 0.uint: - await s.readExactly(addr result[0], int(size)) - -suite "Circuit Relay": - asyncTeardown: - await allFutures(src.stop(), dst.stop(), rel.stop()) - checkTrackers() - - var - protos {.threadvar.}: seq[string] - customProto {.threadvar.}: LPProtocol - ma {.threadvar.}: MultiAddress - src {.threadvar.}: Switch - dst {.threadvar.}: Switch - rel {.threadvar.}: Switch - relaySrc {.threadvar.}: Relay - relayDst {.threadvar.}: Relay - relayRel {.threadvar.}: Relay - conn {.threadvar.}: Connection - msg {.threadvar.}: ProtoBuffer - rcv {.threadvar.}: Option[RelayMessage] - - proc createMsg( - msgType: Option[RelayType] = RelayType.none, - status: Option[RelayStatus] = RelayStatus.none, - src: Option[RelayPeer] = RelayPeer.none, - dst: Option[RelayPeer] = RelayPeer.none): ProtoBuffer = - encodeMsg(RelayMessage(msgType: msgType, srcPeer: src, dstPeer: dst, status: status)) - - proc checkMsg(msg: Option[RelayMessage], - msgType: Option[RelayType] = none[RelayType](), - status: Option[RelayStatus] = none[RelayStatus](), - src: Option[RelayPeer] = none[RelayPeer](), - dst: Option[RelayPeer] = none[RelayPeer]()): bool = - msg.isSome and msg.get == RelayMessage(msgType: msgType, srcPeer: src, dstPeer: dst, status: status) - - proc customHandler(conn: Connection, proto: string) {.async.} = - check "line1" == string.fromBytes(await conn.readLp(1024)) - await conn.writeLp("line2") - check "line3" == string.fromBytes(await conn.readLp(1024)) - await conn.writeLp("line4") - await conn.close() - - asyncSetup: - # Create a custom prototype - protos = @[ "/customProto", RelayCodec ] - customProto = new LPProtocol - customProto.handler = customHandler - customProto.codec = protos[0] - ma = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() - - src = newStandardSwitch() - rel = newStandardSwitch() - dst = SwitchBuilder - .new() - .withRng(newRng()) - .withAddresses(@[ ma ]) - .withTcpTransport() - .withMplex() - .withNoise() - .build() - - relaySrc = Relay.new(src, false) - relayDst = Relay.new(dst, false) - relayRel = Relay.new(rel, true) - - src.mount(relaySrc) - dst.mount(relayDst) - dst.mount(customProto) - rel.mount(relayRel) - - src.addTransport(RelayTransport.new(relaySrc)) - dst.addTransport(RelayTransport.new(relayDst)) - - await src.start() - await dst.start() - await rel.start() - - asyncTest "Handle CanHop": - msg = createMsg(some(CanHop)) - conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) - await conn.writeLp(msg.buffer) - rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) - check rcv.checkMsg(some(Status), some(RelayStatus.Success)) - - conn = await src.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec) - await conn.writeLp(msg.buffer) - rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) - check rcv.checkMsg(some(Status), some(HopCantSpeakRelay)) - - await conn.close() - - asyncTest "Malformed": - conn = await rel.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec) - msg = createMsg(some(RelayType.Status)) - await conn.writeLp(msg.buffer) - rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) - await conn.close() - check rcv.checkMsg(some(Status), some(MalformedMessage)) - - asyncTest "Handle Stop Error": - conn = await rel.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec) - msg = createMsg(some(RelayType.Stop), - none(RelayStatus), - none(RelayPeer), - some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs))) - await conn.writeLp(msg.buffer) - rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) - check rcv.checkMsg(some(Status), some(StopSrcMultiaddrInvalid)) - - conn = await rel.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec) - msg = createMsg(some(RelayType.Stop), - none(RelayStatus), - some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), - none(RelayPeer)) - await conn.writeLp(msg.buffer) - rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) - check rcv.checkMsg(some(Status), some(StopDstMultiaddrInvalid)) - - conn = await rel.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec) - msg = createMsg(some(RelayType.Stop), - none(RelayStatus), - some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)), - some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs))) - await conn.writeLp(msg.buffer) - rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) - await conn.close() - check rcv.checkMsg(some(Status), some(StopDstMultiaddrInvalid)) - - asyncTest "Handle Hop Error": - conn = await src.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayCodec) - msg = createMsg(some(RelayType.Hop)) - await conn.writeLp(msg.buffer) - rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) - check rcv.checkMsg(some(Status), some(HopCantSpeakRelay)) - - conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) - msg = createMsg(some(RelayType.Hop), - none(RelayStatus), - none(RelayPeer), - some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs))) - await conn.writeLp(msg.buffer) - rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) - check rcv.checkMsg(some(Status), some(HopSrcMultiaddrInvalid)) - - conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) - msg = createMsg(some(RelayType.Hop), - none(RelayStatus), - some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)), - some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs))) - await conn.writeLp(msg.buffer) - rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) - check rcv.checkMsg(some(Status), some(HopSrcMultiaddrInvalid)) - - conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) - msg = createMsg(some(RelayType.Hop), - none(RelayStatus), - some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), - none(RelayPeer)) - await conn.writeLp(msg.buffer) - rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) - check rcv.checkMsg(some(Status), some(HopDstMultiaddrInvalid)) - - conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) - msg = createMsg(some(RelayType.Hop), - none(RelayStatus), - some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), - some(RelayPeer(peerId: rel.peerInfo.peerId, addrs: rel.peerInfo.addrs))) - await conn.writeLp(msg.buffer) - rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) - check rcv.checkMsg(some(Status), some(HopCantRelayToSelf)) - - conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) - msg = createMsg(some(RelayType.Hop), - none(RelayStatus), - some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), - some(RelayPeer(peerId: rel.peerInfo.peerId, addrs: rel.peerInfo.addrs))) - await conn.writeLp(msg.buffer) - rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) - check rcv.checkMsg(some(Status), some(HopCantRelayToSelf)) - - conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) - msg = createMsg(some(RelayType.Hop), - none(RelayStatus), - some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), - some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs))) - await conn.writeLp(msg.buffer) - rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) - check rcv.checkMsg(some(Status), some(HopNoConnToDst)) - - await rel.connect(dst.peerInfo.peerId, dst.peerInfo.addrs) - - relayRel.maxCircuit = 0 - conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) - await conn.writeLp(msg.buffer) - rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) - check rcv.checkMsg(some(Status), some(HopCantSpeakRelay)) - relayRel.maxCircuit = relay.MaxCircuit - await conn.close() - - relayRel.maxCircuitPerPeer = 0 - conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) - await conn.writeLp(msg.buffer) - rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) - check rcv.checkMsg(some(Status), some(HopCantSpeakRelay)) - relayRel.maxCircuitPerPeer = relay.MaxCircuitPerPeer - await conn.close() - - let dst2 = newStandardSwitch() - await dst2.start() - await rel.connect(dst2.peerInfo.peerId, dst2.peerInfo.addrs) - - conn = await src.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayCodec) - msg = createMsg(some(RelayType.Hop), - none(RelayStatus), - some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), - some(RelayPeer(peerId: dst2.peerInfo.peerId, addrs: dst2.peerInfo.addrs))) - await conn.writeLp(msg.buffer) - rcv = relay.decodeMsg(await conn.readLp(relay.MsgSize)) - check rcv.checkMsg(some(Status), some(HopCantDialDst)) - await allFutures(dst2.stop()) - - asyncTest "Dial Peer": - let maStr = $rel.peerInfo.addrs[0] & "/p2p/" & $rel.peerInfo.peerId & "/p2p-circuit/p2p/" & $dst.peerInfo.peerId - let maddr = MultiAddress.init(maStr).tryGet() - await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) - await rel.connect(dst.peerInfo.peerId, dst.peerInfo.addrs) - conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0]) - - await conn.writeLp("line1") - check string.fromBytes(await conn.readLp(1024)) == "line2" - - await conn.writeLp("line3") - check string.fromBytes(await conn.readLp(1024)) == "line4" - - asyncTest "SwitchBuilder withRelay": - let - maSrc = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() - maRel = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() - maDst = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() - srcWR = SwitchBuilder.new() - .withRng(newRng()) - .withAddresses(@[ maSrc ]) - .withTcpTransport() - .withMplex() - .withNoise() - .withRelayTransport(false) - .build() - relWR = SwitchBuilder.new() - .withRng(newRng()) - .withAddresses(@[ maRel ]) - .withTcpTransport() - .withMplex() - .withNoise() - .withRelayTransport(true) - .build() - dstWR = SwitchBuilder.new() - .withRng(newRng()) - .withAddresses(@[ maDst ]) - .withTcpTransport() - .withMplex() - .withNoise() - .withRelayTransport(false) - .build() - - dstWR.mount(customProto) - - await srcWR.start() - await dstWR.start() - await relWR.start() - - let maStr = $relWR.peerInfo.addrs[0] & "/p2p/" & $relWR.peerInfo.peerId & "/p2p-circuit/p2p/" & $dstWR.peerInfo.peerId - let maddr = MultiAddress.init(maStr).tryGet() - await srcWR.connect(relWR.peerInfo.peerId, relWR.peerInfo.addrs) - await relWR.connect(dstWR.peerInfo.peerId, dstWR.peerInfo.addrs) - conn = await srcWR.dial(dstWR.peerInfo.peerId, @[ maddr ], protos[0]) - - await conn.writeLp("line1") - check string.fromBytes(await conn.readLp(1024)) == "line2" - - await conn.writeLp("line3") - check string.fromBytes(await conn.readLp(1024)) == "line4" - - await allFutures(srcWR.stop(), dstWR.stop(), relWR.stop()) - - asyncTest "Bad MultiAddress": - await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) - await rel.connect(dst.peerInfo.peerId, dst.peerInfo.addrs) - expect(CatchableError): - let maStr = $rel.peerInfo.addrs[0] & "/p2p/" & $rel.peerInfo.peerId & "/p2p/" & $dst.peerInfo.peerId - let maddr = MultiAddress.init(maStr).tryGet() - conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0]) - - expect(CatchableError): - let maStr = $rel.peerInfo.addrs[0] & "/p2p/" & $rel.peerInfo.peerId - let maddr = MultiAddress.init(maStr).tryGet() - conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0]) - - expect(CatchableError): - let maStr = "/ip4/127.0.0.1" - let maddr = MultiAddress.init(maStr).tryGet() - conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0]) - - expect(CatchableError): - let maStr = $dst.peerInfo.peerId - let maddr = MultiAddress.init(maStr).tryGet() - conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0]) diff --git a/tests/testrelayv1.nim b/tests/testrelayv1.nim new file mode 100644 index 000000000..17206be5c --- /dev/null +++ b/tests/testrelayv1.nim @@ -0,0 +1,293 @@ +{.used.} + +import options, bearssl, chronos +import stew/byteutils +import ../libp2p/[protocols/relay/relay, + protocols/relay/client, + protocols/relay/messages, + protocols/relay/utils, + protocols/relay/rtransport, + multiaddress, + peerinfo, + peerid, + stream/connection, + multistream, + transports/transport, + switch, + builders, + upgrademngrs/upgrade, + varint, + daemon/daemonapi] +import ./helpers + +proc new(T: typedesc[RelayTransport], relay: Relay): T = + T.new(relay = relay, upgrader = relay.switch.transports[0].upgrader) + +suite "Circuit Relay": + asyncTeardown: + await allFutures(src.stop(), dst.stop(), srelay.stop()) + checkTrackers() + + var + protos {.threadvar.}: seq[string] + customProto {.threadvar.}: LPProtocol + ma {.threadvar.}: MultiAddress + src {.threadvar.}: Switch + dst {.threadvar.}: Switch + srelay {.threadvar.}: Switch + clSrc {.threadvar.}: RelayClient + clDst {.threadvar.}: RelayClient + r {.threadvar.}: Relay + conn {.threadvar.}: Connection + msg {.threadvar.}: ProtoBuffer + rcv {.threadvar.}: Option[RelayMessage] + + proc createMsg( + msgType: Option[RelayType] = RelayType.none, + status: Option[StatusV1] = StatusV1.none, + src: Option[RelayPeer] = RelayPeer.none, + dst: Option[RelayPeer] = RelayPeer.none): ProtoBuffer = + encode(RelayMessage(msgType: msgType, srcPeer: src, dstPeer: dst, status: status)) + + proc checkMsg(msg: Option[RelayMessage], + msgType: Option[RelayType] = none[RelayType](), + status: Option[StatusV1] = none[StatusV1](), + src: Option[RelayPeer] = none[RelayPeer](), + dst: Option[RelayPeer] = none[RelayPeer]()) = + check: msg.isSome + let m = msg.get() + check: m.msgType == msgType + check: m.status == status + check: m.srcPeer == src + check: m.dstPeer == dst + + proc customHandler(conn: Connection, proto: string) {.async.} = + check "line1" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("line2") + check "line3" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("line4") + await conn.close() + + asyncSetup: + # Create a custom prototype + protos = @[ "/customProto", RelayV1Codec ] + customProto = new LPProtocol + customProto.handler = customHandler + customProto.codec = protos[0] + + ma = MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + clSrc = RelayClient.new() + clDst = RelayClient.new() + r = Relay.new(circuitRelayV1=true) + src = SwitchBuilder.new() + .withRng(newRng()) + .withAddresses(@[ MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() ]) + .withTcpTransport() + .withMplex() + .withNoise() + .withCircuitRelay(clSrc) + .build() + dst = SwitchBuilder.new() + .withRng(newRng()) + .withAddresses(@[ MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() ]) + .withTcpTransport() + .withMplex() + .withNoise() + .withCircuitRelay(clDst) + .build() + srelay = SwitchBuilder.new() + .withRng(newRng()) + .withAddresses(@[ MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() ]) + .withTcpTransport() + .withMplex() + .withNoise() + .withCircuitRelay(r) + .build() + + dst.mount(customProto) + + await src.start() + await dst.start() + await srelay.start() + + asyncTest "Handle CanHop": + msg = createMsg(some(CanHop)) + conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) + await conn.writeLp(msg.buffer) + rcv = RelayMessage.decode(await conn.readLp(1024)) + rcv.checkMsg(some(RelayType.Status), some(StatusV1.Success)) + + conn = await src.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec) + await conn.writeLp(msg.buffer) + rcv = RelayMessage.decode(await conn.readLp(1024)) + rcv.checkMsg(some(RelayType.Status), some(HopCantSpeakRelay)) + + await conn.close() + + asyncTest "Malformed": + conn = await srelay.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec) + msg = createMsg(some(RelayType.Status)) + await conn.writeLp(msg.buffer) + rcv = RelayMessage.decode(await conn.readLp(1024)) + await conn.close() + rcv.checkMsg(some(RelayType.Status), some(StatusV1.MalformedMessage)) + + asyncTest "Handle Stop Error": + conn = await srelay.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec) + msg = createMsg(some(RelayType.Stop), + none(StatusV1), + none(RelayPeer), + some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs))) + await conn.writeLp(msg.buffer) + rcv = RelayMessage.decode(await conn.readLp(1024)) + rcv.checkMsg(some(RelayType.Status), some(StopSrcMultiaddrInvalid)) + + conn = await srelay.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec) + msg = createMsg(some(RelayType.Stop), + none(StatusV1), + some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), + none(RelayPeer)) + await conn.writeLp(msg.buffer) + rcv = RelayMessage.decode(await conn.readLp(1024)) + rcv.checkMsg(some(RelayType.Status), some(StopDstMultiaddrInvalid)) + + conn = await srelay.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec) + msg = createMsg(some(RelayType.Stop), + none(StatusV1), + some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)), + some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs))) + await conn.writeLp(msg.buffer) + rcv = RelayMessage.decode(await conn.readLp(1024)) + await conn.close() + rcv.checkMsg(some(RelayType.Status), some(StopDstMultiaddrInvalid)) + + asyncTest "Handle Hop Error": + conn = await src.dial(dst.peerInfo.peerId, dst.peerInfo.addrs, RelayV1Codec) + msg = createMsg(some(RelayType.Hop)) + await conn.writeLp(msg.buffer) + rcv = RelayMessage.decode(await conn.readLp(1024)) + rcv.checkMsg(some(RelayType.Status), some(HopCantSpeakRelay)) + + conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) + msg = createMsg(some(RelayType.Hop), + none(StatusV1), + none(RelayPeer), + some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs))) + await conn.writeLp(msg.buffer) + rcv = RelayMessage.decode(await conn.readLp(1024)) + rcv.checkMsg(some(RelayType.Status), some(HopSrcMultiaddrInvalid)) + + conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) + msg = createMsg(some(RelayType.Hop), + none(StatusV1), + some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs)), + some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs))) + await conn.writeLp(msg.buffer) + rcv = RelayMessage.decode(await conn.readLp(1024)) + rcv.checkMsg(some(RelayType.Status), some(HopSrcMultiaddrInvalid)) + + conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) + msg = createMsg(some(RelayType.Hop), + none(StatusV1), + some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), + none(RelayPeer)) + await conn.writeLp(msg.buffer) + rcv = RelayMessage.decode(await conn.readLp(1024)) + rcv.checkMsg(some(RelayType.Status), some(HopDstMultiaddrInvalid)) + + conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) + msg = createMsg(some(RelayType.Hop), + none(StatusV1), + some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), + some(RelayPeer(peerId: srelay.peerInfo.peerId, addrs: srelay.peerInfo.addrs))) + await conn.writeLp(msg.buffer) + rcv = RelayMessage.decode(await conn.readLp(1024)) + rcv.checkMsg(some(RelayType.Status), some(HopCantRelayToSelf)) + + conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) + msg = createMsg(some(RelayType.Hop), + none(StatusV1), + some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), + some(RelayPeer(peerId: srelay.peerInfo.peerId, addrs: srelay.peerInfo.addrs))) + await conn.writeLp(msg.buffer) + rcv = RelayMessage.decode(await conn.readLp(1024)) + rcv.checkMsg(some(RelayType.Status), some(HopCantRelayToSelf)) + + conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) + msg = createMsg(some(RelayType.Hop), + none(StatusV1), + some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), + some(RelayPeer(peerId: dst.peerInfo.peerId, addrs: dst.peerInfo.addrs))) + await conn.writeLp(msg.buffer) + rcv = RelayMessage.decode(await conn.readLp(1024)) + rcv.checkMsg(some(RelayType.Status), some(HopNoConnToDst)) + + await srelay.connect(dst.peerInfo.peerId, dst.peerInfo.addrs) + + var tmp = r.maxCircuit + r.maxCircuit = 0 + conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) + await conn.writeLp(msg.buffer) + rcv = RelayMessage.decode(await conn.readLp(1024)) + rcv.checkMsg(some(RelayType.Status), some(HopCantSpeakRelay)) + r.maxCircuit = tmp + await conn.close() + + tmp = r.maxCircuitPerPeer + r.maxCircuitPerPeer = 0 + conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) + await conn.writeLp(msg.buffer) + rcv = RelayMessage.decode(await conn.readLp(1024)) + rcv.checkMsg(some(RelayType.Status), some(HopCantSpeakRelay)) + r.maxCircuitPerPeer = tmp + await conn.close() + + let dst2 = newStandardSwitch() + await dst2.start() + await srelay.connect(dst2.peerInfo.peerId, dst2.peerInfo.addrs) + + conn = await src.dial(srelay.peerInfo.peerId, srelay.peerInfo.addrs, RelayV1Codec) + msg = createMsg(some(RelayType.Hop), + none(StatusV1), + some(RelayPeer(peerId: src.peerInfo.peerId, addrs: src.peerInfo.addrs)), + some(RelayPeer(peerId: dst2.peerInfo.peerId, addrs: dst2.peerInfo.addrs))) + await conn.writeLp(msg.buffer) + rcv = RelayMessage.decode(await conn.readLp(1024)) + rcv.checkMsg(some(RelayType.Status), some(HopCantDialDst)) + await allFutures(dst2.stop()) + + asyncTest "Dial Peer": + let maStr = $srelay.peerInfo.addrs[0] & "/p2p/" & $srelay.peerInfo.peerId & "/p2p-circuit/p2p/" & $dst.peerInfo.peerId + let maddr = MultiAddress.init(maStr).tryGet() + await src.connect(srelay.peerInfo.peerId, srelay.peerInfo.addrs) + await srelay.connect(dst.peerInfo.peerId, dst.peerInfo.addrs) + conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0]) + + await conn.writeLp("line1") + check string.fromBytes(await conn.readLp(1024)) == "line2" + + await conn.writeLp("line3") + check string.fromBytes(await conn.readLp(1024)) == "line4" + + asyncTest "Bad MultiAddress": + await src.connect(srelay.peerInfo.peerId, srelay.peerInfo.addrs) + await srelay.connect(dst.peerInfo.peerId, dst.peerInfo.addrs) + expect(CatchableError): + let maStr = $srelay.peerInfo.addrs[0] & "/p2p/" & $srelay.peerInfo.peerId & "/p2p/" & $dst.peerInfo.peerId + let maddr = MultiAddress.init(maStr).tryGet() + conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0]) + + expect(CatchableError): + let maStr = $srelay.peerInfo.addrs[0] & "/p2p/" & $srelay.peerInfo.peerId + let maddr = MultiAddress.init(maStr).tryGet() + conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0]) + + expect(CatchableError): + let maStr = "/ip4/127.0.0.1" + let maddr = MultiAddress.init(maStr).tryGet() + conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0]) + + expect(CatchableError): + let maStr = $dst.peerInfo.peerId + let maddr = MultiAddress.init(maStr).tryGet() + conn = await src.dial(dst.peerInfo.peerId, @[ maddr ], protos[0]) diff --git a/tests/testrelayv2.nim b/tests/testrelayv2.nim new file mode 100644 index 000000000..5ce1f640a --- /dev/null +++ b/tests/testrelayv2.nim @@ -0,0 +1,412 @@ +{.used.} + +import bearssl, chronos, options +import ../libp2p +import ../libp2p/[protocols/relay/relay, + protocols/relay/messages, + protocols/relay/utils, + protocols/relay/client] +import ./helpers +import std/times +import stew/byteutils + +proc createSwitch(r: Relay): Switch = + result = SwitchBuilder.new() + .withRng(newRng()) + .withAddresses(@[ MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet() ]) + .withTcpTransport() + .withMplex() + .withNoise() + .withCircuitRelay(r) + .build() + +suite "Circuit Relay V2": + + suite "Reservation": + asyncTeardown: + await allFutures(src1.stop(), src2.stop(), rel.stop()) + checkTrackers() + var + ttl {.threadvar.}: int + ldur {.threadvar.}: uint32 + ldata {.threadvar.}: uint64 + cl1 {.threadvar.}: RelayClient + cl2 {.threadvar.}: RelayClient + rv2 {.threadvar.}: Relay + src1 {.threadvar.}: Switch + src2 {.threadvar.}: Switch + rel {.threadvar.}: Switch + rsvp {.threadvar.}: Rsvp + range {.threadvar.}: HSlice[times.DateTime, times.DateTime] + + asyncSetup: + ttl = 1 + ldur = 60 + ldata = 2048 + cl1 = RelayClient.new() + cl2 = RelayClient.new() + rv2 = Relay.new(reservationTTL=initDuration(seconds=ttl), + limitDuration=ldur, + limitData=ldata, + maxCircuit=1) + src1 = createSwitch(cl1) + src2 = createSwitch(cl2) + rel = createSwitch(rv2) + + await src1.start() + await src2.start() + await rel.start() + await src1.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + await src2.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + rsvp = await cl1.reserve(rel.peerInfo.peerId, rel.peerInfo.addrs) + range = now().utc + (ttl-3).seconds..now().utc + (ttl+3).seconds + check: + rsvp.expire.int64.fromUnix.utc in range + rsvp.limitDuration == ldur + rsvp.limitData == ldata + + asyncTest "Too many reservations": + let conn = await cl2.switch.dial(rel.peerInfo.peerId, rel.peerInfo.addrs, RelayV2HopCodec) + let pb = encode(HopMessage(msgType: HopMessageType.Reserve)) + await conn.writeLp(pb.buffer) + let msg = HopMessage.decode(await conn.readLp(RelayMsgSize)).get() + check: + msg.msgType == HopMessageType.Status + msg.status == some(StatusV2.ReservationRefused) + + asyncTest "Too many reservations + Reconnect": + expect(ReservationError): + discard await cl2.reserve(rel.peerInfo.peerId, rel.peerInfo.addrs) + await rel.disconnect(src1.peerInfo.peerId) + rsvp = await cl2.reserve(rel.peerInfo.peerId, rel.peerInfo.addrs) + range = now().utc + (ttl-3).seconds..now().utc + (ttl+3).seconds + check: + rsvp.expire.int64.fromUnix.utc in range + rsvp.limitDuration == ldur + rsvp.limitData == ldata + + asyncTest "Reservation ttl expires": + await sleepAsync(chronos.timer.seconds(ttl + 1)) + rsvp = await cl1.reserve(rel.peerInfo.peerId, rel.peerInfo.addrs) + range = now().utc + (ttl-3).seconds..now().utc + (ttl+3).seconds + check: + rsvp.expire.int64.fromUnix.utc in range + rsvp.limitDuration == ldur + rsvp.limitData == ldata + + asyncTest "Reservation over relay": + let + rv2add = Relay.new() + addrs = @[ MultiAddress.init($rel.peerInfo.addrs[0] & "/p2p/" & + $rel.peerInfo.peerId & "/p2p-circuit/p2p/" & + $src2.peerInfo.peerId).get() ] + rv2add.setup(src2) + await rv2add.start() + src2.mount(rv2add) + rv2.maxCircuit.inc() + + rsvp = await cl2.reserve(rel.peerInfo.peerId, rel.peerInfo.addrs) + range = now().utc + (ttl-3).seconds..now().utc + (ttl+3).seconds + check: + rsvp.expire.int64.fromUnix.utc in range + rsvp.limitDuration == ldur + rsvp.limitData == ldata + expect(ReservationError): + discard await cl1.reserve(src2.peerInfo.peerId, addrs) + + suite "Connection": + asyncTeardown: + checkTrackers() + var + addrs {.threadvar.}: MultiAddress + customProtoCodec {.threadvar.}: string + proto {.threadvar.}: LPProtocol + ttl {.threadvar.}: int + ldur {.threadvar.}: uint32 + ldata {.threadvar.}: uint64 + srcCl {.threadvar.}: RelayClient + dstCl {.threadvar.}: RelayClient + rv2 {.threadvar.}: Relay + src {.threadvar.}: Switch + dst {.threadvar.}: Switch + rel {.threadvar.}: Switch + rsvp {.threadvar.}: Rsvp + conn {.threadvar.}: Connection + + asyncSetup: + customProtoCodec = "/test" + proto = new LPProtocol + proto.codec = customProtoCodec + ttl = 60 + ldur = 120 + ldata = 16384 + srcCl = RelayClient.new() + dstCl = RelayClient.new() + src = createSwitch(srcCl) + dst = createSwitch(dstCl) + rel = newStandardSwitch() + addrs = MultiAddress.init($rel.peerInfo.addrs[0] & "/p2p/" & + $rel.peerInfo.peerId & "/p2p-circuit/p2p/" & + $dst.peerInfo.peerId).get() + + asyncTest "Connection succeed": + proto.handler = proc(conn: Connection, proto: string) {.async.} = + check: "test1" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("test2") + check: "test3" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("test4") + await conn.close() + rv2 = Relay.new(reservationTTL=initDuration(seconds=ttl), + limitDuration=ldur, + limitData=ldata) + rv2.setup(rel) + rel.mount(rv2) + dst.mount(proto) + + await rel.start() + await src.start() + await dst.start() + + await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + await dst.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + + rsvp = await dstCl.reserve(rel.peerInfo.peerId, rel.peerInfo.addrs) + + conn = await src.dial(dst.peerInfo.peerId, @[ addrs ], customProtoCodec) + await conn.writeLp("test1") + check: "test2" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("test3") + check: "test4" == string.fromBytes(await conn.readLp(1024)) + await allFutures(conn.close()) + await allFutures(src.stop(), dst.stop(), rel.stop()) + + asyncTest "Connection duration exceeded": + ldur = 2 + proto.handler = proc(conn: Connection, proto: string) {.async.} = + check "wanna sleep?" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("yeah!") + check "go!" == string.fromBytes(await conn.readLp(1024)) + await sleepAsync(3000) + await conn.writeLp("that was a cool power nap") + await conn.close() + rv2 = Relay.new(reservationTTL=initDuration(seconds=ttl), + limitDuration=ldur, + limitData=ldata) + rv2.setup(rel) + rel.mount(rv2) + dst.mount(proto) + + await rel.start() + await src.start() + await dst.start() + + await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + await dst.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + + rsvp = await dstCl.reserve(rel.peerInfo.peerId, rel.peerInfo.addrs) + conn = await src.dial(dst.peerInfo.peerId, @[ addrs ], customProtoCodec) + await conn.writeLp("wanna sleep?") + check: "yeah!" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("go!") + expect(LPStreamEOFError): + discard await conn.readLp(1024) + await allFutures(conn.close()) + await allFutures(src.stop(), dst.stop(), rel.stop()) + + asyncTest "Connection data exceeded": + ldata = 1000 + proto.handler = proc(conn: Connection, proto: string) {.async.} = + check "count me the better story you know" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("do you expect a lorem ipsum or...?") + check "surprise me!" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("""Call me Ishmael. Some years ago--never mind how long +precisely--having little or no money in my purse, and nothing +particular to interest me on shore, I thought I would sail about a +little and see the watery part of the world. It is a way I have of +driving off the spleen and regulating the circulation. Whenever I +find myself growing grim about the mouth; whenever it is a damp, +drizzly November in my soul; whenever I find myself involuntarily +pausing before coffin warehouses, and bringing up the rear of every +funeral I meet; and especially whenever my hypos get such an upper +hand of me, that it requires a strong moral principle to prevent me +from deliberately stepping into the street, and methodically knocking +people's hats off--then, I account it high time to get to sea as soon +as I can. This is my substitute for pistol and ball. With a +philosophical flourish Cato throws himself upon his sword; I quietly +take to the ship.""") + rv2 = Relay.new(reservationTTL=initDuration(seconds=ttl), + limitDuration=ldur, + limitData=ldata) + rv2.setup(rel) + rel.mount(rv2) + dst.mount(proto) + + await rel.start() + await src.start() + await dst.start() + + await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + await dst.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + + rsvp = await dstCl.reserve(rel.peerInfo.peerId, rel.peerInfo.addrs) + conn = await src.dial(dst.peerInfo.peerId, @[ addrs ], customProtoCodec) + await conn.writeLp("count me the better story you know") + check: "do you expect a lorem ipsum or...?" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("surprise me!") + expect(LPStreamEOFError): + discard await conn.readLp(1024) + await allFutures(conn.close()) + await allFutures(src.stop(), dst.stop(), rel.stop()) + + asyncTest "Reservation ttl expire during connection": + ttl = 1 + proto.handler = proc(conn: Connection, proto: string) {.async.} = + check: "test1" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("test2") + check: "test3" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("test4") + await conn.close() + rv2 = Relay.new(reservationTTL=initDuration(seconds=ttl), + limitDuration=ldur, + limitData=ldata) + rv2.setup(rel) + rel.mount(rv2) + dst.mount(proto) + + await rel.start() + await src.start() + await dst.start() + + await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + await dst.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + + rsvp = await dstCl.reserve(rel.peerInfo.peerId, rel.peerInfo.addrs) + conn = await src.dial(dst.peerInfo.peerId, @[ addrs ], customProtoCodec) + await conn.writeLp("test1") + check: "test2" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("test3") + check: "test4" == string.fromBytes(await conn.readLp(1024)) + await src.disconnect(rel.peerInfo.peerId) + await sleepAsync(2000) + + expect(DialFailedError): + check: conn.atEof() + await conn.close() + await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + conn = await src.dial(dst.peerInfo.peerId, @[ addrs ], customProtoCodec) + await allFutures(conn.close()) + await allFutures(src.stop(), dst.stop(), rel.stop()) + + asyncTest "Connection over relay": + # src => rel => rel2 => dst + # rel2 reserve rel + # dst reserve rel2 + # src try to connect with dst + proto.handler = proc(conn: Connection, proto: string) {.async.} = + raise newException(CatchableError, "Should not be here") + let + rel2Cl = RelayClient.new(canHop = true) + rel2 = createSwitch(rel2Cl) + rv2 = Relay.new() + addrs = @[ MultiAddress.init($rel.peerInfo.addrs[0] & "/p2p/" & + $rel.peerInfo.peerId & "/p2p-circuit/p2p/" & + $rel2.peerInfo.peerId & "/p2p/" & + $rel2.peerInfo.peerId & "/p2p-circuit/p2p/" & + $dst.peerInfo.peerId).get() ] + rv2.setup(rel) + rel.mount(rv2) + dst.mount(proto) + await rel.start() + await rel2.start() + await src.start() + await dst.start() + + await src.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + await rel2.connect(rel.peerInfo.peerId, rel.peerInfo.addrs) + await dst.connect(rel2.peerInfo.peerId, rel2.peerInfo.addrs) + + rsvp = await rel2Cl.reserve(rel.peerInfo.peerId, rel.peerInfo.addrs) + let rsvp2 = await dstCl.reserve(rel2.peerInfo.peerId, rel2.peerInfo.addrs) + + expect(DialFailedError): + conn = await src.dial(dst.peerInfo.peerId, addrs, customProtoCodec) + await allFutures(conn.close()) + await allFutures(src.stop(), dst.stop(), rel.stop(), rel2.stop()) + + asyncTest "Connection using ClientRelay": + var + protoABC = new LPProtocol + protoBCA = new LPProtocol + protoCAB = new LPProtocol + protoABC.codec = "/abctest" + protoABC.handler = proc(conn: Connection, proto: string) {.async.} = + check: "testABC1" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("testABC2") + check: "testABC3" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("testABC4") + await conn.close() + protoBCA.codec = "/bcatest" + protoBCA.handler = proc(conn: Connection, proto: string) {.async.} = + check: "testBCA1" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("testBCA2") + check: "testBCA3" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("testBCA4") + await conn.close() + protoCAB.codec = "/cabtest" + protoCAB.handler = proc(conn: Connection, proto: string) {.async.} = + check: "testCAB1" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("testCAB2") + check: "testCAB3" == string.fromBytes(await conn.readLp(1024)) + await conn.writeLp("testCAB4") + await conn.close() + + let + clientA = RelayClient.new(canHop = true) + clientB = RelayClient.new(canHop = true) + clientC = RelayClient.new(canHop = true) + switchA = createSwitch(clientA) + switchB = createSwitch(clientB) + switchC = createSwitch(clientC) + addrsABC = MultiAddress.init($switchB.peerInfo.addrs[0] & "/p2p/" & + $switchB.peerInfo.peerId & "/p2p-circuit/p2p/" & + $switchC.peerInfo.peerId).get() + addrsBCA = MultiAddress.init($switchC.peerInfo.addrs[0] & "/p2p/" & + $switchC.peerInfo.peerId & "/p2p-circuit/p2p/" & + $switchA.peerInfo.peerId).get() + addrsCAB = MultiAddress.init($switchA.peerInfo.addrs[0] & "/p2p/" & + $switchA.peerInfo.peerId & "/p2p-circuit/p2p/" & + $switchB.peerInfo.peerId).get() + switchA.mount(protoBCA) + switchB.mount(protoCAB) + switchC.mount(protoABC) + + await switchA.start() + await switchB.start() + await switchC.start() + + await switchA.connect(switchB.peerInfo.peerId, switchB.peerInfo.addrs) + await switchB.connect(switchC.peerInfo.peerId, switchC.peerInfo.addrs) + await switchC.connect(switchA.peerInfo.peerId, switchA.peerInfo.addrs) + let rsvpABC = await clientA.reserve(switchC.peerInfo.peerId, switchC.peerInfo.addrs) + let rsvpBCA = await clientB.reserve(switchA.peerInfo.peerId, switchA.peerInfo.addrs) + let rsvpCAB = await clientC.reserve(switchB.peerInfo.peerId, switchB.peerInfo.addrs) + let connABC = await switchA.dial(switchC.peerInfo.peerId, @[ addrsABC ], "/abctest") + let connBCA = await switchB.dial(switchA.peerInfo.peerId, @[ addrsBCA ], "/bcatest") + let connCAB = await switchC.dial(switchB.peerInfo.peerId, @[ addrsCAB ], "/cabtest") + + await connABC.writeLp("testABC1") + await connBCA.writeLp("testBCA1") + await connCAB.writeLp("testCAB1") + check: + "testABC2" == string.fromBytes(await connABC.readLp(1024)) + "testBCA2" == string.fromBytes(await connBCA.readLp(1024)) + "testCAB2" == string.fromBytes(await connCAB.readLp(1024)) + await connABC.writeLp("testABC3") + await connBCA.writeLp("testBCA3") + await connCAB.writeLp("testCAB3") + check: + "testABC4" == string.fromBytes(await connABC.readLp(1024)) + "testBCA4" == string.fromBytes(await connBCA.readLp(1024)) + "testCAB4" == string.fromBytes(await connCAB.readLp(1024)) + await allFutures(connABC.close(), connBCA.close(), connCAB.close()) + await allFutures(switchA.stop(), switchB.stop(), switchC.stop())