allow connection to a peer with unknown PeerId (#756)

Co-authored-by: Tanguy <tanguy@status.im>
This commit is contained in:
Jacek Sieka 2022-09-05 14:31:14 +02:00 committed by GitHub
parent 1de7508b64
commit dfbfbe6eb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 108 additions and 69 deletions

View File

@ -31,7 +31,6 @@ const
type type
Curve25519* = object Curve25519* = object
Curve25519Key* = array[Curve25519KeySize, byte] Curve25519Key* = array[Curve25519KeySize, byte]
pcuchar = ptr char
Curve25519Error* = enum Curve25519Error* = enum
Curver25519GenError Curver25519GenError

View File

@ -31,6 +31,13 @@ method connect*(
doAssert(false, "Not implemented!") doAssert(false, "Not implemented!")
method connect*(
self: Dial,
addrs: seq[MultiAddress]): Future[PeerId] {.async, base.} =
## Connects to a peer and retrieve its PeerId
doAssert(false, "Not implemented!")
method dial*( method dial*(
self: Dial, self: Dial,
peerId: PeerId, peerId: PeerId,

View File

@ -47,7 +47,7 @@ type
proc dialAndUpgrade( proc dialAndUpgrade(
self: Dialer, self: Dialer,
peerId: PeerId, peerId: Opt[PeerId],
addrs: seq[MultiAddress]): addrs: seq[MultiAddress]):
Future[Connection] {.async.} = Future[Connection] {.async.} =
debug "Dialing peer", peerId debug "Dialing peer", peerId
@ -74,9 +74,6 @@ proc dialAndUpgrade(
libp2p_failed_dials.inc() libp2p_failed_dials.inc()
continue # Try the next address continue # Try the next address
# make sure to assign the peer to the connection
dialed.peerId = peerId
# also keep track of the connection's bottom unsafe transport direction # also keep track of the connection's bottom unsafe transport direction
# required by gossipsub scoring # required by gossipsub scoring
dialed.transportDir = Direction.Out dialed.transportDir = Direction.Out
@ -84,7 +81,7 @@ proc dialAndUpgrade(
libp2p_successful_dials.inc() libp2p_successful_dials.inc()
let conn = try: let conn = try:
await transport.upgradeOutgoing(dialed) await transport.upgradeOutgoing(dialed, peerId)
except CatchableError as exc: except CatchableError as exc:
# If we failed to establish the connection through one transport, # If we failed to establish the connection through one transport,
# we won't succeeded through another - no use in trying again # we won't succeeded through another - no use in trying again
@ -101,20 +98,22 @@ proc dialAndUpgrade(
proc internalConnect( proc internalConnect(
self: Dialer, self: Dialer,
peerId: PeerId, peerId: Opt[PeerId],
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
forceDial: bool): forceDial: bool):
Future[Connection] {.async.} = Future[Connection] {.async.} =
if self.localPeerId == peerId: if Opt.some(self.localPeerId) == peerId:
raise newException(CatchableError, "can't dial self!") raise newException(CatchableError, "can't dial self!")
# Ensure there's only one in-flight attempt per peer # Ensure there's only one in-flight attempt per peer
let lock = self.dialLock.mgetOrPut(peerId, newAsyncLock()) let lock = self.dialLock.mgetOrPut(peerId.get(default(PeerId)), newAsyncLock())
try: try:
await lock.acquire() await lock.acquire()
# Check if we have a connection already and try to reuse it # Check if we have a connection already and try to reuse it
var conn = self.connManager.selectConn(peerId) var conn =
if peerId.isSome: self.connManager.selectConn(peerId.get())
else: nil
if conn != nil: if conn != nil:
if conn.atEof or conn.closed: if conn.atEof or conn.closed:
# This connection should already have been removed from the connection # This connection should already have been removed from the connection
@ -165,7 +164,15 @@ method connect*(
if self.connManager.connCount(peerId) > 0: if self.connManager.connCount(peerId) > 0:
return return
discard await self.internalConnect(peerId, addrs, forceDial) discard await self.internalConnect(Opt.some(peerId), addrs, forceDial)
method connect*(
self: Dialer,
addrs: seq[MultiAddress],
): Future[PeerId] {.async.} =
## Connects to a peer and retrieve its PeerId
return (await self.internalConnect(Opt.none(PeerId), addrs, false)).peerId
proc negotiateStream( proc negotiateStream(
self: Dialer, self: Dialer,
@ -190,7 +197,7 @@ method tryDial*(
trace "Check if it can dial", peerId, addrs trace "Check if it can dial", peerId, addrs
try: try:
let conn = await self.dialAndUpgrade(peerId, addrs) let conn = await self.dialAndUpgrade(Opt.some(peerId), addrs)
if conn.isNil(): if conn.isNil():
raise newException(DialFailedError, "No valid multiaddress") raise newException(DialFailedError, "No valid multiaddress")
await conn.close() await conn.close()
@ -238,7 +245,7 @@ method dial*(
try: try:
trace "Dialing (new)", peerId, protos trace "Dialing (new)", peerId, protos
conn = await self.internalConnect(peerId, addrs, forceDial) conn = await self.internalConnect(Opt.some(peerId), addrs, forceDial)
trace "Opening stream", conn trace "Opening stream", conn
stream = await self.connManager.getStream(conn) stream = await self.connManager.getStream(conn)

View File

@ -38,7 +38,7 @@ const
# https://godoc.org/github.com/libp2p/go-libp2p-noise#pkg-constants # https://godoc.org/github.com/libp2p/go-libp2p-noise#pkg-constants
NoiseCodec* = "/noise" NoiseCodec* = "/noise"
PayloadString = "noise-libp2p-static-key:" PayloadString = toBytes("noise-libp2p-static-key:")
ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256" ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256"
@ -339,7 +339,6 @@ proc handshakeXXOutbound(
hs = HandshakeState.init() hs = HandshakeState.init()
try: try:
hs.ss.mixHash(p.commonPrologue) hs.ss.mixHash(p.commonPrologue)
hs.s = p.noiseKeys hs.s = p.noiseKeys
@ -445,7 +444,6 @@ method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
dumpMessage(sconn, FlowDirection.Incoming, []) dumpMessage(sconn, FlowDirection.Incoming, [])
trace "Received 0-length message", sconn trace "Received 0-length message", sconn
proc encryptFrame( proc encryptFrame(
sconn: NoiseConnection, sconn: NoiseConnection,
cipherFrame: var openArray[byte], cipherFrame: var openArray[byte],
@ -506,7 +504,7 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] =
# sequencing issues # sequencing issues
sconn.stream.write(cipherFrames) sconn.stream.write(cipherFrames)
method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} = method handshake*(p: Noise, conn: Connection, initiator: bool, peerId: Opt[PeerId]): Future[SecureConn] {.async.} =
trace "Starting Noise handshake", conn, initiator trace "Starting Noise handshake", conn, initiator
let timeout = conn.timeout let timeout = conn.timeout
@ -515,7 +513,7 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
# https://github.com/libp2p/specs/tree/master/noise#libp2p-data-in-handshake-messages # https://github.com/libp2p/specs/tree/master/noise#libp2p-data-in-handshake-messages
let let
signedPayload = p.localPrivateKey.sign( signedPayload = p.localPrivateKey.sign(
PayloadString.toBytes & p.noiseKeys.publicKey.getBytes).tryGet() PayloadString & p.noiseKeys.publicKey.getBytes).tryGet()
var var
libp2pProof = initProtoBuffer() libp2pProof = initProtoBuffer()
@ -538,11 +536,9 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
remoteSig: Signature remoteSig: Signature
remoteSigBytes: seq[byte] remoteSigBytes: seq[byte]
let r1 = remoteProof.getField(1, remotePubKeyBytes) if not remoteProof.getField(1, remotePubKeyBytes).valueOr(false):
let r2 = remoteProof.getField(2, remoteSigBytes)
if r1.isErr() or not(r1.get()):
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ")") raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ")")
if r2.isErr() or not(r2.get()): if not remoteProof.getField(2, remoteSigBytes).valueOr(false):
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ")") raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ")")
if not remotePubKey.init(remotePubKeyBytes): if not remotePubKey.init(remotePubKeyBytes):
@ -550,33 +546,34 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
if not remoteSig.init(remoteSigBytes): if not remoteSig.init(remoteSigBytes):
raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ")") raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ")")
let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes let verifyPayload = PayloadString & handshakeRes.rs.getBytes
if not remoteSig.verify(verifyPayload, remotePubKey): if not remoteSig.verify(verifyPayload, remotePubKey):
raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.") raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.")
else: else:
trace "Remote signature verified", conn trace "Remote signature verified", conn
if initiator: let pid = PeerId.init(remotePubKey).valueOr:
let pid = PeerId.init(remotePubKey) raise newException(NoiseHandshakeError, "Invalid remote peer id: " & $error)
if not conn.peerId.validate():
raise newException(NoiseHandshakeError, "Failed to validate peerId.") trace "Remote peer id", pid = $pid
if pid.isErr or pid.get() != conn.peerId:
if peerId.isSome():
let targetPid = peerId.get()
if not targetPid.validate():
raise newException(NoiseHandshakeError, "Failed to validate expected peerId.")
if pid != targetPid:
var var
failedKey: PublicKey failedKey: PublicKey
discard extractPublicKey(conn.peerId, failedKey) discard extractPublicKey(targetPid, failedKey)
debug "Noise handshake, peer infos don't match!", debug "Noise handshake, peer id doesn't match!",
initiator, dealt_peer = conn, initiator, dealt_peer = conn,
dealt_key = $failedKey, received_peer = $pid, dealt_key = $failedKey, received_peer = $pid,
received_key = $remotePubKey received_key = $remotePubKey
raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerId) raise newException(NoiseHandshakeError, "Noise handshake, peer id don't match! " & $pid & " != " & $targetPid)
else: conn.peerId = pid
let pid = PeerId.init(remotePubKey)
if pid.isErr:
raise newException(NoiseHandshakeError, "Invalid remote peer id")
conn.peerId = pid.get()
var tmp = NoiseConnection.new(conn, conn.peerId, conn.observedAddr) var tmp = NoiseConnection.new(conn, conn.peerId, conn.observedAddr)
if initiator: if initiator:
tmp.readCs = handshakeRes.cs2 tmp.readCs = handshakeRes.cs2
tmp.writeCs = handshakeRes.cs1 tmp.writeCs = handshakeRes.cs1

View File

@ -291,7 +291,7 @@ proc transactMessage(conn: Connection,
await conn.write(msg) await conn.write(msg)
return await conn.readRawMessage() return await conn.readRawMessage()
method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[SecureConn] {.async.} = method handshake*(s: Secio, conn: Connection, initiator: bool, peerId: Opt[PeerId]): Future[SecureConn] {.async.} =
var var
localNonce: array[SecioNonceSize, byte] localNonce: array[SecioNonceSize, byte]
remoteNonce: seq[byte] remoteNonce: seq[byte]
@ -342,8 +342,13 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S
remotePeerId = PeerId.init(remotePubkey).tryGet() remotePeerId = PeerId.init(remotePubkey).tryGet()
# TODO: PeerId check against supplied PeerId if peerId.isSome():
if not initiator: let targetPid = peerId.get()
if not targetPid.validate():
raise newException(SecioError, "Failed to validate expected peerId.")
if remotePeerId != targetPid:
raise newException(SecioError, "Peer ids don't match!")
conn.peerId = remotePeerId conn.peerId = remotePeerId
let order = getOrder(remoteBytesPubkey, localNonce, localBytesPubkey, let order = getOrder(remoteBytesPubkey, localNonce, localBytesPubkey,
remoteNonce).tryGet() remoteNonce).tryGet()

View File

@ -79,13 +79,15 @@ method getWrapped*(s: SecureConn): Connection = s.stream
method handshake*(s: Secure, method handshake*(s: Secure,
conn: Connection, conn: Connection,
initiator: bool): Future[SecureConn] {.async, base.} = initiator: bool,
peerId: Opt[PeerId]): Future[SecureConn] {.async, base.} =
doAssert(false, "Not implemented!") doAssert(false, "Not implemented!")
proc handleConn(s: Secure, proc handleConn(s: Secure,
conn: Connection, conn: Connection,
initiator: bool): Future[Connection] {.async.} = initiator: bool,
var sconn = await s.handshake(conn, initiator) peerId: Opt[PeerId]): Future[Connection] {.async.} =
var sconn = await s.handshake(conn, initiator, peerId)
# mark connection bottom level transport direction # mark connection bottom level transport direction
# this is the safest place to do this # this is the safest place to do this
# we require this information in for example gossipsub # we require this information in for example gossipsub
@ -121,7 +123,7 @@ method init*(s: Secure) =
try: try:
# We don't need the result but we # We don't need the result but we
# definitely need to await the handshake # definitely need to await the handshake
discard await s.handleConn(conn, false) discard await s.handleConn(conn, false, Opt.none(PeerId))
trace "connection secured", conn trace "connection secured", conn
except CancelledError as exc: except CancelledError as exc:
warn "securing connection canceled", conn warn "securing connection canceled", conn
@ -135,9 +137,10 @@ method init*(s: Secure) =
method secure*(s: Secure, method secure*(s: Secure,
conn: Connection, conn: Connection,
initiator: bool): initiator: bool,
peerId: Opt[PeerId]):
Future[Connection] {.base.} = Future[Connection] {.base.} =
s.handleConn(conn, initiator) s.handleConn(conn, initiator, peerId)
method readOnce*(s: SecureConn, method readOnce*(s: SecureConn,
pbytes: pointer, pbytes: pointer,

View File

@ -128,6 +128,13 @@ method connect*(
s.dialer.connect(peerId, addrs, forceDial) s.dialer.connect(peerId, addrs, forceDial)
method connect*(
s: Switch,
addrs: seq[MultiAddress]): Future[PeerId] =
## Connects to a peer and retrieve its PeerId
s.dialer.connect(addrs)
method dial*( method dial*(
s: Switch, s: Switch,
peerId: PeerId, peerId: PeerId,

View File

@ -87,12 +87,13 @@ method upgradeIncoming*(
method upgradeOutgoing*( method upgradeOutgoing*(
self: Transport, self: Transport,
conn: Connection): Future[Connection] {.base, gcsafe.} = conn: Connection,
peerId: Opt[PeerId]): Future[Connection] {.base, gcsafe.} =
## base upgrade method that the transport uses to perform ## base upgrade method that the transport uses to perform
## transport specific upgrades ## transport specific upgrades
## ##
self.upgrader.upgradeOutgoing(conn) self.upgrader.upgradeOutgoing(conn, peerId)
method handles*( method handles*(
self: Transport, self: Transport,

View File

@ -88,10 +88,11 @@ proc mux*(
method upgradeOutgoing*( method upgradeOutgoing*(
self: MuxedUpgrade, self: MuxedUpgrade,
conn: Connection): Future[Connection] {.async, gcsafe.} = conn: Connection,
peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} =
trace "Upgrading outgoing connection", conn trace "Upgrading outgoing connection", conn
let sconn = await self.secure(conn) # secure the connection let sconn = await self.secure(conn, peerId) # secure the connection
if isNil(sconn): if isNil(sconn):
raise newException(UpgradeFailedError, raise newException(UpgradeFailedError,
"unable to secure connection, stopping upgrade") "unable to secure connection, stopping upgrade")
@ -129,7 +130,7 @@ method upgradeIncoming*(
var cconn = conn var cconn = conn
try: try:
var sconn = await secure.secure(cconn, false) var sconn = await secure.secure(cconn, false, Opt.none(PeerId))
if isNil(sconn): if isNil(sconn):
return return

View File

@ -47,12 +47,14 @@ method upgradeIncoming*(
method upgradeOutgoing*( method upgradeOutgoing*(
self: Upgrade, self: Upgrade,
conn: Connection): Future[Connection] {.base.} = conn: Connection,
peerId: Opt[PeerId]): Future[Connection] {.base.} =
doAssert(false, "Not implemented!") doAssert(false, "Not implemented!")
proc secure*( proc secure*(
self: Upgrade, self: Upgrade,
conn: Connection): Future[Connection] {.async, gcsafe.} = conn: Connection,
peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} =
if self.secureManagers.len <= 0: if self.secureManagers.len <= 0:
raise newException(UpgradeFailedError, "No secure managers registered!") raise newException(UpgradeFailedError, "No secure managers registered!")
@ -67,7 +69,7 @@ proc secure*(
# let's avoid duplicating checks but detect if it fails to do it properly # let's avoid duplicating checks but detect if it fails to do it properly
doAssert(secureProtocol.len > 0) doAssert(secureProtocol.len > 0)
return await secureProtocol[0].secure(conn, true) return await secureProtocol[0].secure(conn, true, peerId)
proc identify*( proc identify*(
self: Upgrade, self: Upgrade,

View File

@ -104,7 +104,7 @@ suite "Noise":
proc acceptHandler() {.async.} = proc acceptHandler() {.async.} =
let conn = await transport1.accept() let conn = await transport1.accept()
let sconn = await serverNoise.secure(conn, false) let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
try: try:
await sconn.write("Hello!") await sconn.write("Hello!")
finally: finally:
@ -119,8 +119,7 @@ suite "Noise":
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true) clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
conn = await transport2.dial(transport1.addrs[0]) conn = await transport2.dial(transport1.addrs[0])
conn.peerId = serverInfo.peerId let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
let sconn = await clientNoise.secure(conn, true)
var msg = newSeq[byte](6) var msg = newSeq[byte](6)
await sconn.readExactly(addr msg[0], 6) await sconn.readExactly(addr msg[0], 6)
@ -149,7 +148,7 @@ suite "Noise":
var conn: Connection var conn: Connection
try: try:
conn = await transport1.accept() conn = await transport1.accept()
discard await serverNoise.secure(conn, false) discard await serverNoise.secure(conn, false, Opt.none(PeerId))
except CatchableError: except CatchableError:
discard discard
finally: finally:
@ -162,11 +161,10 @@ suite "Noise":
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs) clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true, commonPrologue = @[1'u8, 2'u8, 3'u8]) clientNoise = Noise.new(rng, clientPrivKey, outgoing = true, commonPrologue = @[1'u8, 2'u8, 3'u8])
conn = await transport2.dial(transport1.addrs[0]) conn = await transport2.dial(transport1.addrs[0])
conn.peerId = serverInfo.peerId
var sconn: Connection = nil var sconn: Connection = nil
expect(NoiseDecryptTagError): expect(NoiseDecryptTagError):
sconn = await clientNoise.secure(conn, true) sconn = await clientNoise.secure(conn, true, Opt.some(conn.peerId))
await conn.close() await conn.close()
await handlerWait await handlerWait
@ -186,7 +184,7 @@ suite "Noise":
proc acceptHandler() {.async, gcsafe.} = proc acceptHandler() {.async, gcsafe.} =
let conn = await transport1.accept() let conn = await transport1.accept()
let sconn = await serverNoise.secure(conn, false) let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
defer: defer:
await sconn.close() await sconn.close()
await conn.close() await conn.close()
@ -202,8 +200,7 @@ suite "Noise":
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs) clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true) clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
conn = await transport2.dial(transport1.addrs[0]) conn = await transport2.dial(transport1.addrs[0])
conn.peerId = serverInfo.peerId let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
let sconn = await clientNoise.secure(conn, true)
await sconn.write("Hello!") await sconn.write("Hello!")
await acceptFut await acceptFut
@ -230,7 +227,7 @@ suite "Noise":
proc acceptHandler() {.async, gcsafe.} = proc acceptHandler() {.async, gcsafe.} =
let conn = await transport1.accept() let conn = await transport1.accept()
let sconn = await serverNoise.secure(conn, false) let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
defer: defer:
await sconn.close() await sconn.close()
let msg = await sconn.readLp(1024*1024) let msg = await sconn.readLp(1024*1024)
@ -244,8 +241,7 @@ suite "Noise":
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs) clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true) clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
conn = await transport2.dial(transport1.addrs[0]) conn = await transport2.dial(transport1.addrs[0])
conn.peerId = serverInfo.peerId let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
let sconn = await clientNoise.secure(conn, true)
await sconn.writeLp(hugePayload) await sconn.writeLp(hugePayload)
await readTask await readTask

View File

@ -201,6 +201,20 @@ suite "Switch":
check not switch1.isConnected(switch2.peerInfo.peerId) check not switch1.isConnected(switch2.peerInfo.peerId)
check not switch2.isConnected(switch1.peerInfo.peerId) check not switch2.isConnected(switch1.peerInfo.peerId)
asyncTest "e2e connect to peer with unkown PeerId":
let switch1 = newStandardSwitch(secureManagers = [SecureProtocol.Noise])
let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Noise])
await switch1.start()
await switch2.start()
check: (await switch2.connect(switch1.peerInfo.addrs)) == switch1.peerInfo.peerId
await switch2.disconnect(switch1.peerInfo.peerId)
await allFuturesThrowing(
switch1.stop(),
switch2.stop()
)
asyncTest "e2e should not leak on peer disconnect": asyncTest "e2e should not leak on peer disconnect":
let switch1 = newStandardSwitch() let switch1 = newStandardSwitch()
let switch2 = newStandardSwitch() let switch2 = newStandardSwitch()