# Nim-LibP2P
# Copyright (c) 2022 Status Research & Development GmbH
# Licensed under either of
#  * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
#  * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

# Circuit relay protocol (codec "/libp2p/circuit/relay/0.1.0").
# This module contains:
#  * protobuf encoding/decoding of relay messages;
#  * a helper that sends Status replies;
#  * the hop-side handler. It validates a relay request, dials the
#    destination with the relay codec and performs the Stop handshake.
#    It then forwards data between the two streams.

{.push raises: [Defect].}

import options
import sequtils, strutils, tables

import chronos, chronicles

import ../peerinfo,
       ../switch,
       ../multiaddress,
       ../stream/connection,
       ../protocols/protocol,
       ../transports/transport,
       ../utility,
       ../errors

const
  # Protocol id; handleHopStream also dials the destination with it.
  RelayCodec* = "/libp2p/circuit/relay/0.1.0"
  # NOTE(review): these look like defaults for the Relay fields msgSize,
  # maxCircuit and maxCircuitPerPeer. The constructor that applies them is
  # not visible here, so confirm.
  MsgSize* = 4096
  MaxCircuit* = 1024
  MaxCircuitPerPeer* = 64

logScope:
  topics = "libp2p relay"

type
  # Message type, written as field 1 of a relay message (see encodeMsg).
  RelayType* = enum
    Hop = 1
    Stop = 2
    Status = 3
    CanHop = 4
  # Status codes, written as field 4 of a relay message (see encodeMsg).
  RelayStatus* = enum
    Success = 100
    HopSrcAddrTooLong = 220
    HopDstAddrTooLong = 221
    HopSrcMultiaddrInvalid = 250
    HopDstMultiaddrInvalid = 251
    HopNoConnToDst = 260
    HopCantDialDst = 261
    HopCantOpenDstStream = 262
    HopCantSpeakRelay = 270
    HopCantRelayToSelf = 280
    StopSrcAddrTooLong = 320
    StopDstAddrTooLong = 321
    StopSrcMultiaddrInvalid = 350
    StopDstMultiaddrInvalid = 351
    StopRelayRefused = 390
    MalformedMessage = 400

  RelayError* = object of LPError

  # A peer id plus its multiaddresses, as carried in a relay message.
  RelayPeer* = object
    peerId*: PeerId
    addrs*: seq[MultiAddress]

  # Callback that takes a Connection. It presumably hands relayed
  # connections to a consumer, but its call site is not visible here, so
  # confirm.
  AddConn* = proc(conn: Connection): Future[void] {.gcsafe, raises: [Defect].}

  # A relay protocol message. Every field is optional, and encodeMsg writes
  # only the fields that are `some`.
  RelayMessage* = object
    msgType*: Option[RelayType]
    srcPeer*: Option[RelayPeer]
    dstPeer*: Option[RelayPeer]
    status*: Option[RelayStatus]

  Relay* = ref object of LPProtocol
    switch*: Switch
    peerId: PeerId
    dialer: Dial
    canHop: bool                  # if false, Hop requests get HopCantSpeakRelay
    streamCount: int              # hop streams currently in handleHopStream
    hopCount: CountTable[PeerId]  # active circuits per peer (as src or dst)
    addConn: AddConn

    maxCircuit*: int              # limit on streamCount
    maxCircuitPerPeer*: int       # limit on hopCount for a single peer
    msgSize*: int                 # size argument passed to readLp for replies

proc encodeMsg*(msg: RelayMessage): ProtoBuffer =
  ## Serialises `msg` into a finished ProtoBuffer. Only fields that are
  ## `some` are written:
  ##   1 = message type ordinal, 2 = source peer,
  ##   3 = destination peer,     4 = status ordinal.
  ## Each peer is a nested buffer. It holds the peer id in field 1 and one
  ## field-2 entry per multiaddress, containing the address's `data.buffer`
  ## bytes.
  result = initProtoBuffer()
  if isSome(msg.msgType):
    result.write(1, msg.msgType.get().ord.uint)
  if isSome(msg.srcPeer):
    var peer = initProtoBuffer()
    peer.write(1, msg.srcPeer.get().peerId)
    for ma in msg.srcPeer.get().addrs:
      peer.write(2, ma.data.buffer)
    peer.finish()
    result.write(2, peer.buffer)
  if isSome(msg.dstPeer):
    var peer = initProtoBuffer()
    peer.write(1, msg.dstPeer.get().peerId)
    for ma in msg.dstPeer.get().addrs:
      peer.write(2, ma.data.buffer)
    peer.finish()
    result.write(3, peer.buffer)
  if isSome(msg.status):
    result.write(4, msg.status.get().ord.uint)

  result.finish()

proc decodeMsg*(buf: seq[byte]): Option[RelayMessage] =
  ## Parses `buf` using the field layout written by `encodeMsg`.
  ## Returns `none` in either of these cases:
  ##  * any top-level field fails to decode;
  ##  * a source or destination peer is present, but its peer id or
  ##    address list fails to decode.
  ## Otherwise returns the message, with only the fields found in `buf` set.
  var
    rMsg: RelayMessage
    msgTypeOrd: uint32
    src: RelayPeer
    dst: RelayPeer
    statusOrd: uint32
    pbSrc: ProtoBuffer
    pbDst: ProtoBuffer

  # Each rN result reports a decode error or whether field N was present.
  let
    pb = initProtoBuffer(buf)
    r1 = pb.getField(1, msgTypeOrd)
    r2 = pb.getField(2, pbSrc)
    r3 = pb.getField(3, pbDst)
    r4 = pb.getField(4, statusOrd)

  if r1.isErr() or r2.isErr() or r3.isErr() or r4.isErr():
    return none(RelayMessage)

  if r2.get() and
     (pbSrc.getField(1, src.peerId).isErr() or
      pbSrc.getRepeatedField(2, src.addrs).isErr()):
    return none(RelayMessage)

  if r3.get() and
     (pbDst.getField(1, dst.peerId).isErr() or
      pbDst.getRepeatedField(2, dst.addrs).isErr()):
    return none(RelayMessage)

  # NOTE(review): msgTypeOrd and statusOrd come from the remote peer, yet
  # they are converted to enums without validation.
  #  * An ordinal outside RelayType's range may raise a RangeDefect when
  #    range checks are on.
  #  * A value that falls in one of RelayStatus's holes yields an invalid
  #    enum value.
  # Consider checking the ordinal before converting.
  if r1.get(): rMsg.msgType = some(RelayType(msgTypeOrd))
  if r2.get(): rMsg.srcPeer = some(src)
  if r3.get(): rMsg.dstPeer = some(dst)
  if r4.get(): rMsg.status = some(RelayStatus(statusOrd))
  some(rMsg)

proc sendStatus*(conn: Connection, code: RelayStatus) {.async, gcsafe.} =
  ## Encodes a message with type `Status` and status `code`, then writes it
  ## to `conn` with `writeLp`.
  trace "send status", status = $code & "(" & $ord(code) & ")"
  let
    msg = RelayMessage(
      msgType: some(RelayType.Status),
      status: some(code))
    pb = encodeMsg(msg)

  await conn.writeLp(pb.buffer)

proc handleHopStream(r: Relay, conn: Connection, msg: RelayMessage) {.async, gcsafe.} =
  ## Handles a Hop request `msg` that arrived on `conn`.
  ## It sends an error status on `conn` and returns when any of these holds:
  ##  * more than `maxCircuit` hop streams are active;
  ##  * hopping is disabled;
  ##  * the source peer is missing or differs from `conn`'s peer;
  ##  * the destination is missing, is this node, or is not connected;
  ##  * either endpoint exceeds `maxCircuitPerPeer` active circuits;
  ##  * dialing the destination with RelayCodec fails;
  ##  * the Stop handshake does not return a `Status` message with
  ##    `Success`.
  ## Otherwise it replies `Success` and defines `bridge` to forward data
  ## between the two streams. The remainder of this proc is not readable
  ## because of the corruption noted further down.

  # The stream is counted before the limit check. The deferred decrement
  # releases it on every exit path.
  r.streamCount.inc()
  defer:
    r.streamCount.dec()

  if r.streamCount > r.maxCircuit:
    trace "refusing connection; too many active circuit"
    await sendStatus(conn, RelayStatus.HopCantSpeakRelay)
    return

  proc checkMsg(): Result[RelayMessage, RelayStatus] =
    ## Validates the request. On failure it returns the status code to send
    ## back to the requester.
    if not r.canHop:
      return err(RelayStatus.HopCantSpeakRelay)
    if msg.srcPeer.isNone:
      return err(RelayStatus.HopSrcMultiaddrInvalid)
    let src = msg.srcPeer.get()
    # The claimed source must be the peer that actually opened this stream.
    if src.peerId != conn.peerId:
      return err(RelayStatus.HopSrcMultiaddrInvalid)
    if msg.dstPeer.isNone:
      return err(RelayStatus.HopDstMultiaddrInvalid)
    let dst = msg.dstPeer.get()
    if dst.peerId == r.switch.peerInfo.peerId:
      return err(RelayStatus.HopCantRelayToSelf)
    if not r.switch.isConnected(dst.peerId):
      trace "relay not connected to dst", dst
      return err(RelayStatus.HopNoConnToDst)
    ok(msg)

  let check = checkMsg()
  if check.isErr:
    await sendStatus(conn, check.error())
    return

  let
    src = msg.srcPeer.get()
    dst = msg.dstPeer.get()

  # TODO: if r.acl # access control list
  #       and not r.acl.AllowHop(src.peerId, dst.peerId)
  #       sendStatus(conn, RelayStatus.HopCantSpeakRelay)

  # Reserve a circuit slot for both endpoints before checking the per-peer
  # limit. The deferred decrements release them on every exit path.
  r.hopCount.inc(src.peerId)
  r.hopCount.inc(dst.peerId)
  defer:
    r.hopCount.inc(src.peerId, -1)
    r.hopCount.inc(dst.peerId, -1)

  if r.hopCount[src.peerId] > r.maxCircuitPerPeer:
    trace "refusing connection; too many connection from src", src, dst
    await sendStatus(conn, RelayStatus.HopCantSpeakRelay)
    return

  if r.hopCount[dst.peerId] > r.maxCircuitPerPeer:
    trace "refusing connection; too many connection to dst", src, dst
    await sendStatus(conn, RelayStatus.HopCantSpeakRelay)
    return

  # Open a stream to the destination that speaks RelayCodec. Cancellation
  # is re-raised; any other failure is reported to the requester as
  # HopCantDialDst.
  let connDst = try:
    await r.switch.dial(dst.peerId, @[RelayCodec])
  except CancelledError as exc:
    raise exc
  except CatchableError as exc:
    trace "error opening relay stream", dst, exc=exc.msg
    await sendStatus(conn, RelayStatus.HopCantDialDst)
    return
  defer:
    await connDst.close()

  # Stop handshake: tell the destination who is connecting through us.
  let msgToSend = RelayMessage(
    msgType: some(RelayType.Stop),
    srcPeer: some(src),
    dstPeer: some(dst),
    status: none(RelayStatus))

  # Send the Stop message, then read and decode the destination's reply
  # (readLp receives r.msgSize).
  let msgRcvFromDstOpt = try:
    await connDst.writeLp(encodeMsg(msgToSend).buffer)
    decodeMsg(await connDst.readLp(r.msgSize))
  except CancelledError as exc:
    raise exc
  except CatchableError as exc:
    trace "error writing stop handshake or reading stop response", exc=exc.msg
    await sendStatus(conn, RelayStatus.HopCantOpenDstStream)
    return

  if msgRcvFromDstOpt.isNone:
    trace "error reading stop response", msg = msgRcvFromDstOpt
    await sendStatus(conn, RelayStatus.HopCantOpenDstStream)
    return

  let
    msgRcvFromDst = msgRcvFromDstOpt.get()
  # The reply must be a Status message whose status is Success.
  # NOTE(review): the log text "unexcepted" below is a typo for "unexpected".
  if msgRcvFromDst.msgType.isNone or msgRcvFromDst.msgType.get() != RelayType.Status:
    trace "unexcepted relay stop response", msgType = msgRcvFromDst.msgType
    await sendStatus(conn, RelayStatus.HopCantOpenDstStream)
    return

  if msgRcvFromDst.status.isNone or msgRcvFromDst.status.get() != RelayStatus.Success:
    trace "relay stop failure", status=msgRcvFromDst.status
    await sendStatus(conn, RelayStatus.HopCantOpenDstStream)
    return

  await sendStatus(conn, RelayStatus.Success)

  trace "relaying connection", src, dst

  # bridge forwards data between `conn` and `connDst`. It uses one
  # 4096-byte buffer and one outstanding readOnce per direction, and keeps
  # a byte counter for each direction. Only its beginning is readable
  # below.
  proc bridge(conn: Connection, connDst: Connection) {.async.} =
    const bufferSize = 4096
    var
      bufSrcToDst: array[bufferSize, byte]
      bufDstToSrc: array[bufferSize, byte]
      futSrc = conn.readOnce(addr bufSrcToDst[0], bufSrcToDst.high + 1)
      futDst = connDst.readOnce(addr bufDstToSrc[0], bufDstToSrc.high + 1)
      bytesSendFromSrcToDst = 0
      bytesSendFromDstToSrc = 0
      bufRead: int

    while not conn.closed() and not connDst.closed():
      try:
        # Wait until either direction has data.
        await futSrc or futDst
        if futSrc.finished():
          bufRead = await futSrc
          bytesSendFromSrcToDst.inc(bufRead)
          # NOTE(review): the next line is corrupted and will not compile.
          # The slice `[0..` stops mid-expression, and the line continues
          # with what looks like the tail of another routine: a check on
          # `sma[^2]` / `sma[^1]` followed by a "Handles return" trace.
          # The rest of `bridge`, the end of `handleHopStream`, and anything
          # defined in between appear to be missing. Restore this region
          # from version control rather than repairing it by hand.
          await connDst.write(@bufSrcToDst[0..= 3: result = CircuitRelay.match(sma[^2].get()) and P2PPattern.match(sma[^1].get()) trace "Handles return", ma, result

proc new*(T: typedesc[RelayTransport], relay: Relay, upgrader: Upgrade): T =
  ## Creates a transport of type `T` that holds `relay` and `upgrader`.
  ## It starts with `running` set to true and a connection queue created
  ## by `newAsyncQueue[Connection](0)`. In chronos, a max size of 0 means
  ## the queue is unbounded.
  result = T(relay: relay, upgrader: upgrader)
  result.running = true
  result.queue = newAsyncQueue[Connection](0)