diff --git a/libp2p/dial.nim b/libp2p/dial.nim index bb8d00cd2..e0d78b906 100644 --- a/libp2p/dial.nim +++ b/libp2p/dial.nim @@ -13,10 +13,13 @@ else: {.push raises: [].} import chronos +import stew/results import peerid, stream/connection, transports/transport +export results + type Dial* = ref object of RootObj @@ -69,5 +72,5 @@ method addTransport*( method tryDial*( self: Dial, peerId: PeerId, - addrs: seq[MultiAddress]): Future[MultiAddress] {.async, base.} = + addrs: seq[MultiAddress]): Future[Opt[MultiAddress]] {.async, base.} = doAssert(false, "Not implemented!") diff --git a/libp2p/dialer.nim b/libp2p/dialer.nim index 85c5b63d9..30ff15e27 100644 --- a/libp2p/dialer.nim +++ b/libp2p/dialer.nim @@ -9,6 +9,7 @@ import std/[sugar, tables] +import stew/results import pkg/[chronos, chronicles, metrics] @@ -24,7 +25,7 @@ import dial, upgrademngrs/upgrade, errors -export dial, errors +export dial, errors, results logScope: topics = "libp2p dialer" @@ -189,7 +190,7 @@ proc negotiateStream( method tryDial*( self: Dialer, peerId: PeerId, - addrs: seq[MultiAddress]): Future[MultiAddress] {.async.} = + addrs: seq[MultiAddress]): Future[Opt[MultiAddress]] {.async.} = ## Create a protocol stream in order to check ## if a connection is possible. ## Doesn't use the Connection Manager to save it. diff --git a/libp2p/protocols/connectivity/autonat.nim b/libp2p/protocols/connectivity/autonat.nim index 4f7fb53eb..6408a7e04 100644 --- a/libp2p/protocols/connectivity/autonat.nim +++ b/libp2p/protocols/connectivity/autonat.nim @@ -13,6 +13,7 @@ else: {.push raises: [].} import std/[options, sets, sequtils] +import stew/results import chronos, chronicles, stew/objects import ../protocol, ../../switch, @@ -226,7 +227,10 @@ proc tryDial(a: Autonat, conn: Connection, addrs: seq[MultiAddress]) {.async.} = try: await a.sem.acquire() let ma = await a.switch.dialer.tryDial(conn.peerId, addrs) - await conn.sendResponseOk(ma) + if ma.isSome: + await conn.sendResponseOk(ma.get()) + else: + await conn.sendResponseError(DialError, "Missing observed address") except CancelledError as exc: raise exc except CatchableError as exc: @@ -241,15 +245,19 @@ proc handleDial(a: Autonat, conn: Connection, msg: AutonatMsg): Future[void] = if peerInfo.id.isSome() and peerInfo.id.get() != conn.peerId: return conn.sendResponseError(BadRequest, "PeerId mismatch") - var isRelayed = conn.observedAddr.contains(multiCodec("p2p-circuit")) + if conn.observedAddr.isNone: + return conn.sendResponseError(BadRequest, "Missing observed address") + let observedAddr = conn.observedAddr.get() + + var isRelayed = observedAddr.contains(multiCodec("p2p-circuit")) if isRelayed.isErr() or isRelayed.get(): return conn.sendResponseError(DialRefused, "Refused to dial a relayed observed address") - let hostIp = conn.observedAddr[0] + let hostIp = observedAddr[0] if hostIp.isErr() or not IP.match(hostIp.get()): - trace "wrong observed address", address=conn.observedAddr + trace "wrong observed address", address=observedAddr return conn.sendResponseError(InternalError, "Expected an IP address") var addrs = initHashSet[MultiAddress]() - addrs.incl(conn.observedAddr) + addrs.incl(observedAddr) for ma in peerInfo.addrs: isRelayed = ma.contains(multiCodec("p2p-circuit")) if isRelayed.isErr() or isRelayed.get(): diff --git a/libp2p/protocols/identify.nim b/libp2p/protocols/identify.nim index 975126ff8..ae7d26c48 100644 --- a/libp2p/protocols/identify.nim +++ b/libp2p/protocols/identify.nim @@ -16,6 +16,7 @@ else: {.push raises: [].} import std/[sequtils, options, strutils, sugar] +import stew/results import chronos, chronicles import ../protobuf/minprotobuf, ../peerinfo, @@ -80,7 +81,7 @@ chronicles.expandIt(IdentifyInfo): if iinfo.signedPeerRecord.isSome(): "Some" else: "None" -proc encodeMsg(peerInfo: PeerInfo, observedAddr: MultiAddress, sendSpr: bool): ProtoBuffer +proc encodeMsg(peerInfo: PeerInfo, observedAddr: Opt[MultiAddress], sendSpr: bool): ProtoBuffer {.raises: [Defect].} = result = initProtoBuffer() @@ -91,7 +92,8 @@ proc encodeMsg(peerInfo: PeerInfo, observedAddr: MultiAddress, sendSpr: bool): P result.write(2, ma.data.buffer) for proto in peerInfo.protocols: result.write(3, proto) - result.write(4, observedAddr.data.buffer) + if observedAddr.isSome: + result.write(4, observedAddr.get().data.buffer) let protoVersion = ProtoVersion result.write(5, protoVersion) let agentVersion = if peerInfo.agentVersion.len <= 0: diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index 8af49de8a..5a1afedbd 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -12,7 +12,8 @@ when (NimMajor, NimMinor) < (1, 4): else: {.push raises: [].} -import std/[sequtils, strutils, tables, hashes] +import std/[sequtils, strutils, tables, hashes, options] +import stew/results import chronos, chronicles, nimcrypto/sha2, metrics import rpc/[messages, message, protobuf], ../../peerid, @@ -174,7 +175,7 @@ proc connectOnce(p: PubSubPeer): Future[void] {.async.} = trace "Get new send connection", p, newConn p.sendConn = newConn - p.address = some(p.sendConn.observedAddr) + p.address = if p.sendConn.observedAddr.isSome: some(p.sendConn.observedAddr.get) else: none(MultiAddress) if p.onEvent != nil: p.onEvent(p, PubSubPeerEvent(kind: PubSubPeerEventKind.Connected)) diff --git a/libp2p/protocols/secure/secure.nim b/libp2p/protocols/secure/secure.nim index 0bdf85209..c187fd2fb 100644 --- a/libp2p/protocols/secure/secure.nim +++ b/libp2p/protocols/secure/secure.nim @@ -13,6 +13,7 @@ else: {.push raises: [].} import std/[strformat] +import stew/results import chronos, chronicles import ../protocol, ../../stream/streamseq, @@ -21,7 +22,7 @@ import ../protocol, ../../peerinfo, ../../errors -export protocol +export protocol, results logScope: topics = "libp2p secure" @@ -48,7 +49,7 @@ chronicles.formatIt(SecureConn): shortLog(it) proc new*(T: type SecureConn, conn: Connection, peerId: PeerId, - observedAddr: MultiAddress, + observedAddr: Opt[MultiAddress], timeout: Duration = DefaultConnectionTimeout): T = result = T(stream: conn, peerId: peerId, diff --git a/libp2p/stream/chronosstream.nim b/libp2p/stream/chronosstream.nim index 2b676d468..d123c98be 100644 --- a/libp2p/stream/chronosstream.nim +++ b/libp2p/stream/chronosstream.nim @@ -13,10 +13,13 @@ else: {.push raises: [].} import std/[oids, strformat] +import stew/results import chronos, chronicles, metrics import connection import ../utility +export results + logScope: topics = "libp2p chronosstream" @@ -60,7 +63,7 @@ proc init*(C: type ChronosStream, client: StreamTransport, dir: Direction, timeout = DefaultChronosStreamTimeout, - observedAddr: MultiAddress = MultiAddress()): ChronosStream = + observedAddr: Opt[MultiAddress]): ChronosStream = result = C(client: client, timeout: timeout, dir: dir, diff --git a/libp2p/stream/connection.nim b/libp2p/stream/connection.nim index 43f58af66..a4f52deae 100644 --- a/libp2p/stream/connection.nim +++ b/libp2p/stream/connection.nim @@ -13,13 +13,14 @@ else: {.push raises: [].} import std/[hashes, oids, strformat] +import stew/results import chronicles, chronos, metrics import lpstream, ../multiaddress, ../peerinfo, ../errors -export lpstream, peerinfo, errors +export lpstream, peerinfo, errors, results logScope: topics = "libp2p connection" @@ -37,7 +38,7 @@ type timerTaskFut: Future[void] # the current timer instance timeoutHandler*: TimeoutHandler # timeout handler peerId*: PeerId - observedAddr*: MultiAddress + observedAddr*: Opt[MultiAddress] upgraded*: Future[void] protocol*: string # protocol used by the connection, used as tag for metrics transportDir*: Direction # The bottom level transport (generally the socket) direction @@ -160,9 +161,9 @@ method getWrapped*(s: Connection): Connection {.base.} = proc new*(C: type Connection, peerId: PeerId, dir: Direction, + observedAddr: Opt[MultiAddress], timeout: Duration = DefaultConnectionTimeout, - timeoutHandler: TimeoutHandler = nil, - observedAddr: MultiAddress = MultiAddress()): Connection = + timeoutHandler: TimeoutHandler = nil): Connection = result = C(peerId: peerId, dir: dir, timeout: timeout, diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim index 24f63f6af..6d0d32182 100644 --- a/libp2p/transports/tcptransport.nim +++ b/libp2p/transports/tcptransport.nim @@ -15,6 +15,7 @@ else: {.push raises: [].} import std/[oids, sequtils] +import stew/results import chronos, chronicles import transport, ../errors, @@ -31,7 +32,7 @@ import transport, logScope: topics = "libp2p tcptransport" -export transport +export transport, results const TcpTransportTrackerName* = "libp2p.tcptransport" @@ -71,18 +72,20 @@ proc setupTcpTransportTracker(): TcpTransportTracker = result.isLeaked = leakTransport addTracker(TcpTransportTrackerName, result) -proc connHandler*(self: TcpTransport, - client: StreamTransport, - dir: Direction): Future[Connection] {.async.} = - var observedAddr: MultiAddress = MultiAddress() +proc getObservedAddr(client: StreamTransport): Future[MultiAddress] {.async.} = try: - observedAddr = MultiAddress.init(client.remoteAddress).tryGet() + return MultiAddress.init(client.remoteAddress).tryGet() except CatchableError as exc: trace "Failed to create observedAddr", exc = exc.msg if not(isNil(client) and client.closed): await client.closeWait() raise exc +proc connHandler*(self: TcpTransport, + client: StreamTransport, + observedAddr: Opt[MultiAddress], + dir: Direction): Future[Connection] {.async.} = + trace "Handling tcp connection", address = $observedAddr, dir = $dir, clients = self.clients[Direction.In].len + @@ -222,7 +225,8 @@ method accept*(self: TcpTransport): Future[Connection] {.async, gcsafe.} = self.acceptFuts[index] = self.servers[index].accept() let transp = await finished - return await self.connHandler(transp, Direction.In) + let observedAddr = await getObservedAddr(transp) + return await self.connHandler(transp, Opt.some(observedAddr), Direction.In) except TransportOsError as exc: # TODO: it doesn't sound like all OS errors # can be ignored, we should re-raise those @@ -250,7 +254,8 @@ method dial*( let transp = await connect(address) try: - return await self.connHandler(transp, Direction.Out) + let observedAddr = await getObservedAddr(transp) + return await self.connHandler(transp, Opt.some(observedAddr), Direction.Out) except CatchableError as err: await transp.closeWait() raise err diff --git a/libp2p/transports/wstransport.nim b/libp2p/transports/wstransport.nim index 50eb7c333..9ded1c1e7 100644 --- a/libp2p/transports/wstransport.nim +++ b/libp2p/transports/wstransport.nim @@ -15,6 +15,7 @@ else: {.push raises: [].} import std/[sequtils] +import stew/results import chronos, chronicles import transport, ../errors, @@ -31,7 +32,7 @@ import transport, logScope: topics = "libp2p wstransport" -export transport, websock +export transport, websock, results const WsTransportTrackerName* = "libp2p.wstransport" @@ -45,8 +46,8 @@ type proc new*(T: type WsStream, session: WSSession, dir: Direction, - timeout = 10.minutes, - observedAddr: MultiAddress = MultiAddress()): T = + observedAddr: Opt[MultiAddress], + timeout = 10.minutes): T = let stream = T( session: session, @@ -221,8 +222,7 @@ proc connHandler(self: WsTransport, await stream.close() raise exc - let conn = WsStream.new(stream, dir) - conn.observedAddr = observedAddr + let conn = WsStream.new(stream, dir, Opt.some(observedAddr)) self.connections[dir].add(conn) proc onClose() {.async.} = diff --git a/tests/commontransport.nim b/tests/commontransport.nim index 532689e04..af29add61 100644 --- a/tests/commontransport.nim +++ b/tests/commontransport.nim @@ -1,7 +1,7 @@ {.used.} import sequtils -import chronos, stew/byteutils +import chronos, stew/[byteutils, results] import ../libp2p/[stream/connection, transports/transport, upgrademngrs/upgrade, @@ -35,14 +35,16 @@ proc commonTransportTest*(name: string, prov: TransportProvider, ma: string) = proc acceptHandler() {.async, gcsafe.} = let conn = await transport1.accept() - check transport1.handles(conn.observedAddr) + if conn.observedAddr.isSome(): + check transport1.handles(conn.observedAddr.get()) await conn.close() let handlerWait = acceptHandler() let conn = await transport2.dial(transport1.addrs[0]) - check transport2.handles(conn.observedAddr) + if conn.observedAddr.isSome(): + check transport2.handles(conn.observedAddr.get()) await conn.close() #for some protocols, closing requires actively reading, so we must close here diff --git a/tests/testconnmngr.nim b/tests/testconnmngr.nim index 5c705bf36..dda2e3dc1 100644 --- a/tests/testconnmngr.nim +++ b/tests/testconnmngr.nim @@ -1,4 +1,5 @@ import sequtils +import stew/results import chronos import ../libp2p/[connmanager, stream/connection, @@ -9,6 +10,9 @@ import ../libp2p/[connmanager, import helpers +proc getConnection(peerId: PeerId, dir: Direction = Direction.In): Connection = + return Connection.new(peerId, dir, Opt.none(MultiAddress)) + type TestMuxer = ref object of Muxer peerId: PeerId @@ -18,7 +22,7 @@ method newStream*( name: string = "", lazy: bool = false): Future[Connection] {.async, gcsafe.} = - result = Connection.new(m.peerId, Direction.Out) + result = getConnection(m.peerId, Direction.Out) suite "Connection Manager": teardown: @@ -27,7 +31,7 @@ suite "Connection Manager": asyncTest "add and retrieve a connection": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = Connection.new(peerId, Direction.In) + let conn = getConnection(peerId) connMngr.storeConn(conn) check conn in connMngr @@ -41,7 +45,7 @@ suite "Connection Manager": asyncTest "shouldn't allow a closed connection": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = Connection.new(peerId, Direction.In) + let conn = getConnection(peerId) await conn.close() expect CatchableError: @@ -52,7 +56,7 @@ suite "Connection Manager": asyncTest "shouldn't allow an EOFed connection": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = Connection.new(peerId, Direction.In) + let conn = getConnection(peerId) conn.isEof = true expect CatchableError: @@ -64,7 +68,7 @@ suite "Connection Manager": asyncTest "add and retrieve a muxer": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = Connection.new(peerId, Direction.In) + let conn = getConnection(peerId) let muxer = new Muxer muxer.connection = conn @@ -80,7 +84,7 @@ suite "Connection Manager": asyncTest "shouldn't allow a muxer for an untracked connection": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = Connection.new(peerId, Direction.In) + let conn = getConnection(peerId) let muxer = new Muxer muxer.connection = conn @@ -94,8 +98,8 @@ suite "Connection Manager": asyncTest "get conn with direction": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn1 = Connection.new(peerId, Direction.Out) - let conn2 = Connection.new(peerId, Direction.In) + let conn1 = getConnection(peerId, Direction.Out) + let conn2 = getConnection(peerId) connMngr.storeConn(conn1) connMngr.storeConn(conn2) @@ -114,7 +118,7 @@ suite "Connection Manager": asyncTest "get muxed stream for peer": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = Connection.new(peerId, Direction.In) + let conn = getConnection(peerId) let muxer = new TestMuxer muxer.peerId = peerId @@ -134,7 +138,7 @@ suite "Connection Manager": asyncTest "get stream from directed connection": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = Connection.new(peerId, Direction.In) + let conn = getConnection(peerId) let muxer = new TestMuxer muxer.peerId = peerId @@ -155,7 +159,7 @@ suite "Connection Manager": asyncTest "get stream from any connection": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = Connection.new(peerId, Direction.In) + let conn = getConnection(peerId) let muxer = new TestMuxer muxer.peerId = peerId @@ -175,11 +179,11 @@ suite "Connection Manager": let connMngr = ConnManager.new(maxConnsPerPeer = 1) let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - connMngr.storeConn(Connection.new(peerId, Direction.In)) + connMngr.storeConn(getConnection(peerId)) let conns = @[ - Connection.new(peerId, Direction.In), - Connection.new(peerId, Direction.In)] + getConnection(peerId), + getConnection(peerId)] expect TooManyConnectionsError: connMngr.storeConn(conns[0]) @@ -193,7 +197,7 @@ suite "Connection Manager": asyncTest "cleanup on connection close": let connMngr = ConnManager.new() let peerId = PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet() - let conn = Connection.new(peerId, Direction.In) + let conn = getConnection(peerId) let muxer = new Muxer muxer.connection = conn @@ -220,7 +224,7 @@ suite "Connection Manager": Direction.In else: Direction.Out - let conn = Connection.new(peerId, dir) + let conn = getConnection(peerId, dir) let muxer = new Muxer muxer.connection = conn @@ -353,7 +357,7 @@ suite "Connection Manager": let slot = await ((connMngr.getOutgoingSlot()).wait(10.millis)) let conn = - Connection.new( + getConnection( PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), Direction.In)