diff --git a/libp2p/crypto/curve25519.nim b/libp2p/crypto/curve25519.nim index d4b476bc7..98a80d7e2 100644 --- a/libp2p/crypto/curve25519.nim +++ b/libp2p/crypto/curve25519.nim @@ -31,7 +31,6 @@ const type Curve25519* = object Curve25519Key* = array[Curve25519KeySize, byte] - pcuchar = ptr char Curve25519Error* = enum Curver25519GenError @@ -77,7 +76,7 @@ proc mulgen(_: type[Curve25519], dst: var Curve25519Key, point: Curve25519Key) = addr rpoint[0], Curve25519KeySize, EC_curve25519) - + assert size == Curve25519KeySize proc public*(private: Curve25519Key): Curve25519Key = diff --git a/libp2p/dial.nim b/libp2p/dial.nim index bfe1620f5..bb8d00cd2 100644 --- a/libp2p/dial.nim +++ b/libp2p/dial.nim @@ -31,6 +31,13 @@ method connect*( doAssert(false, "Not implemented!") +method connect*( + self: Dial, + addrs: seq[MultiAddress]): Future[PeerId] {.async, base.} = + ## Connects to a peer and retrieve its PeerId + + doAssert(false, "Not implemented!") + method dial*( self: Dial, peerId: PeerId, diff --git a/libp2p/dialer.nim b/libp2p/dialer.nim index 532ca3617..85c5b63d9 100644 --- a/libp2p/dialer.nim +++ b/libp2p/dialer.nim @@ -47,7 +47,7 @@ type proc dialAndUpgrade( self: Dialer, - peerId: PeerId, + peerId: Opt[PeerId], addrs: seq[MultiAddress]): Future[Connection] {.async.} = debug "Dialing peer", peerId @@ -74,9 +74,6 @@ proc dialAndUpgrade( libp2p_failed_dials.inc() continue # Try the next address - # make sure to assign the peer to the connection - dialed.peerId = peerId - # also keep track of the connection's bottom unsafe transport direction # required by gossipsub scoring dialed.transportDir = Direction.Out @@ -84,7 +81,7 @@ proc dialAndUpgrade( libp2p_successful_dials.inc() let conn = try: - await transport.upgradeOutgoing(dialed) + await transport.upgradeOutgoing(dialed, peerId) except CatchableError as exc: # If we failed to establish the connection through one transport, # we won't succeeded through another - no use in trying again @@ -101,20 +98,22 @@ proc dialAndUpgrade( proc internalConnect( self: Dialer, - peerId: PeerId, + peerId: Opt[PeerId], addrs: seq[MultiAddress], forceDial: bool): Future[Connection] {.async.} = - if self.localPeerId == peerId: + if Opt.some(self.localPeerId) == peerId: raise newException(CatchableError, "can't dial self!") # Ensure there's only one in-flight attempt per peer - let lock = self.dialLock.mgetOrPut(peerId, newAsyncLock()) + let lock = self.dialLock.mgetOrPut(peerId.get(default(PeerId)), newAsyncLock()) try: await lock.acquire() # Check if we have a connection already and try to reuse it - var conn = self.connManager.selectConn(peerId) + var conn = + if peerId.isSome: self.connManager.selectConn(peerId.get()) + else: nil if conn != nil: if conn.atEof or conn.closed: # This connection should already have been removed from the connection @@ -165,7 +164,15 @@ method connect*( if self.connManager.connCount(peerId) > 0: return - discard await self.internalConnect(peerId, addrs, forceDial) + discard await self.internalConnect(Opt.some(peerId), addrs, forceDial) + +method connect*( + self: Dialer, + addrs: seq[MultiAddress], + ): Future[PeerId] {.async.} = + ## Connects to a peer and retrieve its PeerId + + return (await self.internalConnect(Opt.none(PeerId), addrs, false)).peerId proc negotiateStream( self: Dialer, @@ -190,7 +197,7 @@ method tryDial*( trace "Check if it can dial", peerId, addrs try: - let conn = await self.dialAndUpgrade(peerId, addrs) + let conn = await self.dialAndUpgrade(Opt.some(peerId), addrs) if conn.isNil(): raise newException(DialFailedError, "No valid multiaddress") await conn.close() @@ -238,7 +245,7 @@ method dial*( try: trace "Dialing (new)", peerId, protos - conn = await self.internalConnect(peerId, addrs, forceDial) + conn = await self.internalConnect(Opt.some(peerId), addrs, forceDial) trace "Opening stream", conn stream = await self.connManager.getStream(conn) diff --git a/libp2p/peerid.nim b/libp2p/peerid.nim index dc7a4aee0..401fee96d 100644 --- a/libp2p/peerid.nim +++ b/libp2p/peerid.nim @@ -43,7 +43,7 @@ func shortLog*(pid: PeerId): string = var spid = $pid if len(spid) > 10: spid[3] = '*' - + when (NimMajor, NimMinor) > (1, 4): spid.delete(4 .. spid.high - 6) else: diff --git a/libp2p/protocols/secure/noise.nim b/libp2p/protocols/secure/noise.nim index 6e613ea42..66a8dcd22 100644 --- a/libp2p/protocols/secure/noise.nim +++ b/libp2p/protocols/secure/noise.nim @@ -38,7 +38,7 @@ const # https://godoc.org/github.com/libp2p/go-libp2p-noise#pkg-constants NoiseCodec* = "/noise" - PayloadString = "noise-libp2p-static-key:" + PayloadString = toBytes("noise-libp2p-static-key:") ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256" @@ -339,7 +339,6 @@ proc handshakeXXOutbound( hs = HandshakeState.init() try: - hs.ss.mixHash(p.commonPrologue) hs.s = p.noiseKeys @@ -445,7 +444,6 @@ method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} = dumpMessage(sconn, FlowDirection.Incoming, []) trace "Received 0-length message", sconn - proc encryptFrame( sconn: NoiseConnection, cipherFrame: var openArray[byte], @@ -506,7 +504,7 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] = # sequencing issues sconn.stream.write(cipherFrames) -method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} = +method handshake*(p: Noise, conn: Connection, initiator: bool, peerId: Opt[PeerId]): Future[SecureConn] {.async.} = trace "Starting Noise handshake", conn, initiator let timeout = conn.timeout @@ -515,7 +513,7 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon # https://github.com/libp2p/specs/tree/master/noise#libp2p-data-in-handshake-messages let signedPayload = p.localPrivateKey.sign( - PayloadString.toBytes & p.noiseKeys.publicKey.getBytes).tryGet() + PayloadString & p.noiseKeys.publicKey.getBytes).tryGet() var libp2pProof = initProtoBuffer() @@ -538,11 +536,9 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon remoteSig: Signature remoteSigBytes: seq[byte] - let r1 = remoteProof.getField(1, remotePubKeyBytes) - let r2 = remoteProof.getField(2, remoteSigBytes) - if r1.isErr() or not(r1.get()): + if not remoteProof.getField(1, remotePubKeyBytes).valueOr(false): raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ")") - if r2.isErr() or not(r2.get()): + if not remoteProof.getField(2, remoteSigBytes).valueOr(false): raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ")") if not remotePubKey.init(remotePubKeyBytes): @@ -550,33 +546,34 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon if not remoteSig.init(remoteSigBytes): raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ")") - let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes + let verifyPayload = PayloadString & handshakeRes.rs.getBytes if not remoteSig.verify(verifyPayload, remotePubKey): raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.") else: trace "Remote signature verified", conn - if initiator: - let pid = PeerId.init(remotePubKey) - if not conn.peerId.validate(): - raise newException(NoiseHandshakeError, "Failed to validate peerId.") - if pid.isErr or pid.get() != conn.peerId: + let pid = PeerId.init(remotePubKey).valueOr: + raise newException(NoiseHandshakeError, "Invalid remote peer id: " & $error) + + trace "Remote peer id", pid = $pid + + if peerId.isSome(): + let targetPid = peerId.get() + if not targetPid.validate(): + raise newException(NoiseHandshakeError, "Failed to validate expected peerId.") + + if pid != targetPid: var failedKey: PublicKey - discard extractPublicKey(conn.peerId, failedKey) - debug "Noise handshake, peer infos don't match!", + discard extractPublicKey(targetPid, failedKey) + debug "Noise handshake, peer id doesn't match!", initiator, dealt_peer = conn, dealt_key = $failedKey, received_peer = $pid, received_key = $remotePubKey - raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerId) - else: - let pid = PeerId.init(remotePubKey) - if pid.isErr: - raise newException(NoiseHandshakeError, "Invalid remote peer id") - conn.peerId = pid.get() + raise newException(NoiseHandshakeError, "Noise handshake, peer id don't match! " & $pid & " != " & $targetPid) + conn.peerId = pid var tmp = NoiseConnection.new(conn, conn.peerId, conn.observedAddr) - if initiator: tmp.readCs = handshakeRes.cs2 tmp.writeCs = handshakeRes.cs1 diff --git a/libp2p/protocols/secure/secio.nim b/libp2p/protocols/secure/secio.nim index 46a05c71c..1ebea90b2 100644 --- a/libp2p/protocols/secure/secio.nim +++ b/libp2p/protocols/secure/secio.nim @@ -291,7 +291,7 @@ proc transactMessage(conn: Connection, await conn.write(msg) return await conn.readRawMessage() -method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[SecureConn] {.async.} = +method handshake*(s: Secio, conn: Connection, initiator: bool, peerId: Opt[PeerId]): Future[SecureConn] {.async.} = var localNonce: array[SecioNonceSize, byte] remoteNonce: seq[byte] @@ -342,9 +342,14 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S remotePeerId = PeerId.init(remotePubkey).tryGet() - # TODO: PeerId check against supplied PeerId - if not initiator: - conn.peerId = remotePeerId + if peerId.isSome(): + let targetPid = peerId.get() + if not targetPid.validate(): + raise newException(SecioError, "Failed to validate expected peerId.") + + if remotePeerId != targetPid: + raise newException(SecioError, "Peer ids don't match!") + conn.peerId = remotePeerId let order = getOrder(remoteBytesPubkey, localNonce, localBytesPubkey, remoteNonce).tryGet() trace "Remote proposal", schemes = remoteExchanges, ciphers = remoteCiphers, diff --git a/libp2p/protocols/secure/secure.nim b/libp2p/protocols/secure/secure.nim index f841a5d04..0bdf85209 100644 --- a/libp2p/protocols/secure/secure.nim +++ b/libp2p/protocols/secure/secure.nim @@ -79,13 +79,15 @@ method getWrapped*(s: SecureConn): Connection = s.stream method handshake*(s: Secure, conn: Connection, - initiator: bool): Future[SecureConn] {.async, base.} = + initiator: bool, + peerId: Opt[PeerId]): Future[SecureConn] {.async, base.} = doAssert(false, "Not implemented!") proc handleConn(s: Secure, conn: Connection, - initiator: bool): Future[Connection] {.async.} = - var sconn = await s.handshake(conn, initiator) + initiator: bool, + peerId: Opt[PeerId]): Future[Connection] {.async.} = + var sconn = await s.handshake(conn, initiator, peerId) # mark connection bottom level transport direction # this is the safest place to do this # we require this information in for example gossipsub @@ -121,7 +123,7 @@ method init*(s: Secure) = try: # We don't need the result but we # definitely need to await the handshake - discard await s.handleConn(conn, false) + discard await s.handleConn(conn, false, Opt.none(PeerId)) trace "connection secured", conn except CancelledError as exc: warn "securing connection canceled", conn @@ -135,9 +137,10 @@ method init*(s: Secure) = method secure*(s: Secure, conn: Connection, - initiator: bool): + initiator: bool, + peerId: Opt[PeerId]): Future[Connection] {.base.} = - s.handleConn(conn, initiator) + s.handleConn(conn, initiator, peerId) method readOnce*(s: SecureConn, pbytes: pointer, diff --git a/libp2p/switch.nim b/libp2p/switch.nim index f6f292510..3bbdaa613 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -128,6 +128,13 @@ method connect*( s.dialer.connect(peerId, addrs, forceDial) +method connect*( + s: Switch, + addrs: seq[MultiAddress]): Future[PeerId] = + ## Connects to a peer and retrieve its PeerId + + s.dialer.connect(addrs) + method dial*( s: Switch, peerId: PeerId, diff --git a/libp2p/transports/transport.nim b/libp2p/transports/transport.nim index 951f8bf83..12d1a0803 100644 --- a/libp2p/transports/transport.nim +++ b/libp2p/transports/transport.nim @@ -87,12 +87,13 @@ method upgradeIncoming*( method upgradeOutgoing*( self: Transport, - conn: Connection): Future[Connection] {.base, gcsafe.} = + conn: Connection, + peerId: Opt[PeerId]): Future[Connection] {.base, gcsafe.} = ## base upgrade method that the transport uses to perform ## transport specific upgrades ## - self.upgrader.upgradeOutgoing(conn) + self.upgrader.upgradeOutgoing(conn, peerId) method handles*( self: Transport, diff --git a/libp2p/upgrademngrs/muxedupgrade.nim b/libp2p/upgrademngrs/muxedupgrade.nim index f60d0c159..030508c10 100644 --- a/libp2p/upgrademngrs/muxedupgrade.nim +++ b/libp2p/upgrademngrs/muxedupgrade.nim @@ -88,10 +88,11 @@ proc mux*( method upgradeOutgoing*( self: MuxedUpgrade, - conn: Connection): Future[Connection] {.async, gcsafe.} = + conn: Connection, + peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} = trace "Upgrading outgoing connection", conn - let sconn = await self.secure(conn) # secure the connection + let sconn = await self.secure(conn, peerId) # secure the connection if isNil(sconn): raise newException(UpgradeFailedError, "unable to secure connection, stopping upgrade") @@ -129,7 +130,7 @@ method upgradeIncoming*( var cconn = conn try: - var sconn = await secure.secure(cconn, false) + var sconn = await secure.secure(cconn, false, Opt.none(PeerId)) if isNil(sconn): return diff --git a/libp2p/upgrademngrs/upgrade.nim b/libp2p/upgrademngrs/upgrade.nim index 781a074bc..c5733e658 100644 --- a/libp2p/upgrademngrs/upgrade.nim +++ b/libp2p/upgrademngrs/upgrade.nim @@ -47,12 +47,14 @@ method upgradeIncoming*( method upgradeOutgoing*( self: Upgrade, - conn: Connection): Future[Connection] {.base.} = + conn: Connection, + peerId: Opt[PeerId]): Future[Connection] {.base.} = doAssert(false, "Not implemented!") proc secure*( self: Upgrade, - conn: Connection): Future[Connection] {.async, gcsafe.} = + conn: Connection, + peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} = if self.secureManagers.len <= 0: raise newException(UpgradeFailedError, "No secure managers registered!") @@ -67,7 +69,7 @@ proc secure*( # let's avoid duplicating checks but detect if it fails to do it properly doAssert(secureProtocol.len > 0) - return await secureProtocol[0].secure(conn, true) + return await secureProtocol[0].secure(conn, true, peerId) proc identify*( self: Upgrade, diff --git a/tests/testnoise.nim b/tests/testnoise.nim index eabe56b58..b0e785be3 100644 --- a/tests/testnoise.nim +++ b/tests/testnoise.nim @@ -104,7 +104,7 @@ suite "Noise": proc acceptHandler() {.async.} = let conn = await transport1.accept() - let sconn = await serverNoise.secure(conn, false) + let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId)) try: await sconn.write("Hello!") finally: @@ -119,8 +119,7 @@ suite "Noise": clientNoise = Noise.new(rng, clientPrivKey, outgoing = true) conn = await transport2.dial(transport1.addrs[0]) - conn.peerId = serverInfo.peerId - let sconn = await clientNoise.secure(conn, true) + let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId)) var msg = newSeq[byte](6) await sconn.readExactly(addr msg[0], 6) @@ -149,7 +148,7 @@ suite "Noise": var conn: Connection try: conn = await transport1.accept() - discard await serverNoise.secure(conn, false) + discard await serverNoise.secure(conn, false, Opt.none(PeerId)) except CatchableError: discard finally: @@ -162,11 +161,10 @@ suite "Noise": clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs) clientNoise = Noise.new(rng, clientPrivKey, outgoing = true, commonPrologue = @[1'u8, 2'u8, 3'u8]) conn = await transport2.dial(transport1.addrs[0]) - conn.peerId = serverInfo.peerId var sconn: Connection = nil expect(NoiseDecryptTagError): - sconn = await clientNoise.secure(conn, true) + sconn = await clientNoise.secure(conn, true, Opt.some(conn.peerId)) await conn.close() await handlerWait @@ -186,7 +184,7 @@ suite "Noise": proc acceptHandler() {.async, gcsafe.} = let conn = await transport1.accept() - let sconn = await serverNoise.secure(conn, false) + let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId)) defer: await sconn.close() await conn.close() @@ -202,8 +200,7 @@ suite "Noise": clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs) clientNoise = Noise.new(rng, clientPrivKey, outgoing = true) conn = await transport2.dial(transport1.addrs[0]) - conn.peerId = serverInfo.peerId - let sconn = await clientNoise.secure(conn, true) + let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId)) await sconn.write("Hello!") await acceptFut @@ -230,7 +227,7 @@ suite "Noise": proc acceptHandler() {.async, gcsafe.} = let conn = await transport1.accept() - let sconn = await serverNoise.secure(conn, false) + let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId)) defer: await sconn.close() let msg = await sconn.readLp(1024*1024) @@ -244,8 +241,7 @@ suite "Noise": clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs) clientNoise = Noise.new(rng, clientPrivKey, outgoing = true) conn = await transport2.dial(transport1.addrs[0]) - conn.peerId = serverInfo.peerId - let sconn = await clientNoise.secure(conn, true) + let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId)) await sconn.writeLp(hugePayload) await readTask diff --git a/tests/testswitch.nim b/tests/testswitch.nim index daaa0f088..3d47b1a5b 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -201,6 +201,20 @@ suite "Switch": check not switch1.isConnected(switch2.peerInfo.peerId) check not switch2.isConnected(switch1.peerInfo.peerId) + asyncTest "e2e connect to peer with unkown PeerId": + let switch1 = newStandardSwitch(secureManagers = [SecureProtocol.Noise]) + let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Noise]) + await switch1.start() + await switch2.start() + + check: (await switch2.connect(switch1.peerInfo.addrs)) == switch1.peerInfo.peerId + await switch2.disconnect(switch1.peerInfo.peerId) + + await allFuturesThrowing( + switch1.stop(), + switch2.stop() + ) + asyncTest "e2e should not leak on peer disconnect": let switch1 = newStandardSwitch() let switch2 = newStandardSwitch()