diff --git a/libp2p/dial.nim b/libp2p/dial.nim index e0d78b906..60d63e35a 100644 --- a/libp2p/dial.nim +++ b/libp2p/dial.nim @@ -36,7 +36,8 @@ method connect*( method connect*( self: Dial, - addrs: seq[MultiAddress]): Future[PeerId] {.async, base.} = + address: MultiAddress, + allowUnknownPeerId = false): Future[PeerId] {.async, base.} = ## Connects to a peer and retrieve its PeerId doAssert(false, "Not implemented!") diff --git a/libp2p/dialer.nim b/libp2p/dialer.nim index d3c0826dd..4e39733a4 100644 --- a/libp2p/dialer.nim +++ b/libp2p/dialer.nim @@ -219,11 +219,23 @@ method connect*( method connect*( self: Dialer, - addrs: seq[MultiAddress], - ): Future[PeerId] {.async.} = + address: MultiAddress, + allowUnknownPeerId = false): Future[PeerId] {.async.} = ## Connects to a peer and retrieve its PeerId - return (await self.internalConnect(Opt.none(PeerId), addrs, false)).peerId + let fullAddress = parseFullAddress(address) + if fullAddress.isOk: + return (await self.internalConnect( + Opt.some(fullAddress.get()[0]), + @[fullAddress.get()[1]], + false)).peerId + else: + if allowUnknownPeerId == false: + raise newException(DialFailedError, "Address without PeerID and unknown peer id disabled!") + return (await self.internalConnect( + Opt.none(PeerId), + @[address], + false)).peerId proc negotiateStream( self: Dialer, diff --git a/libp2p/switch.nim b/libp2p/switch.nim index b1a7d489e..85fe583ef 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -130,10 +130,15 @@ method connect*( method connect*( s: Switch, - addrs: seq[MultiAddress]): Future[PeerId] = + address: MultiAddress, + allowUnknownPeerId = false): Future[PeerId] = ## Connects to a peer and retrieve its PeerId + ## + ## If the P2P part is missing from the MA and `allowUnknownPeerId` is set + ## to true, this will discover the PeerId while connecting. This exposes + ## you to MiTM attacks, so it shouldn't be used without care! - s.dialer.connect(addrs) + s.dialer.connect(address, allowUnknownPeerId) method dial*( s: Switch, diff --git a/tests/testswitch.nim b/tests/testswitch.nim index 3cbb66d7e..92a7d9f53 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -10,6 +10,7 @@ import ../libp2p/[errors, builders, stream/bufferstream, stream/connection, + multicodec, multiaddress, peerinfo, crypto/crypto, @@ -213,12 +214,40 @@ suite "Switch": "dnsaddr=" & $switch1.peerInfo.addrs[0] & "/p2p/" & $switch1.peerInfo.peerId, ] - check: (await switch2.connect(@[MultiAddress.init("/dnsaddr/test.io/").tryGet()])) == switch1.peerInfo.peerId + check: (await switch2.connect(MultiAddress.init("/dnsaddr/test.io/").tryGet(), true)) == switch1.peerInfo.peerId await switch2.disconnect(switch1.peerInfo.peerId) # via direct ip check not switch2.isConnected(switch1.peerInfo.peerId) - check: (await switch2.connect(switch1.peerInfo.addrs)) == switch1.peerInfo.peerId + check: (await switch2.connect(switch1.peerInfo.addrs[0], true)) == switch1.peerInfo.peerId + + await switch2.disconnect(switch1.peerInfo.peerId) + + await allFuturesThrowing( + switch1.stop(), + switch2.stop() + ) + + asyncTest "e2e connect to peer with known PeerId": + let switch1 = newStandardSwitch(secureManagers = [SecureProtocol.Noise]) + let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Noise]) + await switch1.start() + await switch2.start() + + # via direct ip + check not switch2.isConnected(switch1.peerInfo.peerId) + + # without specifying allow unknown, will fail + expect(DialFailedError): + discard await switch2.connect(switch1.peerInfo.addrs[0]) + + # with invalid PeerId, will fail + let fakeMa = concat(switch1.peerInfo.addrs[0], MultiAddress.init(multiCodec("p2p"), PeerId.random.tryGet().data).tryGet()).tryGet() + expect(CatchableError): + discard (await switch2.connect(fakeMa)) + + # real thing works + check (await switch2.connect(switch1.peerInfo.fullAddrs.tryGet()[0])) == switch1.peerInfo.peerId await switch2.disconnect(switch1.peerInfo.peerId)