allow connection to a peer with unknown PeerId (#756)

Co-authored-by: Tanguy <tanguy@status.im>
This commit is contained in:
Jacek Sieka 2022-09-05 14:31:14 +02:00 committed by GitHub
parent 1de7508b64
commit dfbfbe6eb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 108 additions and 69 deletions

View File

@ -31,7 +31,6 @@ const
type
Curve25519* = object
Curve25519Key* = array[Curve25519KeySize, byte]
pcuchar = ptr char
Curve25519Error* = enum
Curver25519GenError

View File

@ -31,6 +31,13 @@ method connect*(
doAssert(false, "Not implemented!")
method connect*(
self: Dial,
addrs: seq[MultiAddress]): Future[PeerId] {.async, base.} =
## Connects to a peer and retrieve its PeerId
doAssert(false, "Not implemented!")
method dial*(
self: Dial,
peerId: PeerId,

View File

@ -47,7 +47,7 @@ type
proc dialAndUpgrade(
self: Dialer,
peerId: PeerId,
peerId: Opt[PeerId],
addrs: seq[MultiAddress]):
Future[Connection] {.async.} =
debug "Dialing peer", peerId
@ -74,9 +74,6 @@ proc dialAndUpgrade(
libp2p_failed_dials.inc()
continue # Try the next address
# make sure to assign the peer to the connection
dialed.peerId = peerId
# also keep track of the connection's bottom unsafe transport direction
# required by gossipsub scoring
dialed.transportDir = Direction.Out
@ -84,7 +81,7 @@ proc dialAndUpgrade(
libp2p_successful_dials.inc()
let conn = try:
await transport.upgradeOutgoing(dialed)
await transport.upgradeOutgoing(dialed, peerId)
except CatchableError as exc:
# If we failed to establish the connection through one transport,
# we won't succeeded through another - no use in trying again
@ -101,20 +98,22 @@ proc dialAndUpgrade(
proc internalConnect(
self: Dialer,
peerId: PeerId,
peerId: Opt[PeerId],
addrs: seq[MultiAddress],
forceDial: bool):
Future[Connection] {.async.} =
if self.localPeerId == peerId:
if Opt.some(self.localPeerId) == peerId:
raise newException(CatchableError, "can't dial self!")
# Ensure there's only one in-flight attempt per peer
let lock = self.dialLock.mgetOrPut(peerId, newAsyncLock())
let lock = self.dialLock.mgetOrPut(peerId.get(default(PeerId)), newAsyncLock())
try:
await lock.acquire()
# Check if we have a connection already and try to reuse it
var conn = self.connManager.selectConn(peerId)
var conn =
if peerId.isSome: self.connManager.selectConn(peerId.get())
else: nil
if conn != nil:
if conn.atEof or conn.closed:
# This connection should already have been removed from the connection
@ -165,7 +164,15 @@ method connect*(
if self.connManager.connCount(peerId) > 0:
return
discard await self.internalConnect(peerId, addrs, forceDial)
discard await self.internalConnect(Opt.some(peerId), addrs, forceDial)
method connect*(
self: Dialer,
addrs: seq[MultiAddress],
): Future[PeerId] {.async.} =
## Connects to a peer and retrieve its PeerId
return (await self.internalConnect(Opt.none(PeerId), addrs, false)).peerId
proc negotiateStream(
self: Dialer,
@ -190,7 +197,7 @@ method tryDial*(
trace "Check if it can dial", peerId, addrs
try:
let conn = await self.dialAndUpgrade(peerId, addrs)
let conn = await self.dialAndUpgrade(Opt.some(peerId), addrs)
if conn.isNil():
raise newException(DialFailedError, "No valid multiaddress")
await conn.close()
@ -238,7 +245,7 @@ method dial*(
try:
trace "Dialing (new)", peerId, protos
conn = await self.internalConnect(peerId, addrs, forceDial)
conn = await self.internalConnect(Opt.some(peerId), addrs, forceDial)
trace "Opening stream", conn
stream = await self.connManager.getStream(conn)

View File

@ -38,7 +38,7 @@ const
# https://godoc.org/github.com/libp2p/go-libp2p-noise#pkg-constants
NoiseCodec* = "/noise"
PayloadString = "noise-libp2p-static-key:"
PayloadString = toBytes("noise-libp2p-static-key:")
ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256"
@ -339,7 +339,6 @@ proc handshakeXXOutbound(
hs = HandshakeState.init()
try:
hs.ss.mixHash(p.commonPrologue)
hs.s = p.noiseKeys
@ -445,7 +444,6 @@ method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
dumpMessage(sconn, FlowDirection.Incoming, [])
trace "Received 0-length message", sconn
proc encryptFrame(
sconn: NoiseConnection,
cipherFrame: var openArray[byte],
@ -506,7 +504,7 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] =
# sequencing issues
sconn.stream.write(cipherFrames)
method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} =
method handshake*(p: Noise, conn: Connection, initiator: bool, peerId: Opt[PeerId]): Future[SecureConn] {.async.} =
trace "Starting Noise handshake", conn, initiator
let timeout = conn.timeout
@ -515,7 +513,7 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
# https://github.com/libp2p/specs/tree/master/noise#libp2p-data-in-handshake-messages
let
signedPayload = p.localPrivateKey.sign(
PayloadString.toBytes & p.noiseKeys.publicKey.getBytes).tryGet()
PayloadString & p.noiseKeys.publicKey.getBytes).tryGet()
var
libp2pProof = initProtoBuffer()
@ -538,11 +536,9 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
remoteSig: Signature
remoteSigBytes: seq[byte]
let r1 = remoteProof.getField(1, remotePubKeyBytes)
let r2 = remoteProof.getField(2, remoteSigBytes)
if r1.isErr() or not(r1.get()):
if not remoteProof.getField(1, remotePubKeyBytes).valueOr(false):
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ")")
if r2.isErr() or not(r2.get()):
if not remoteProof.getField(2, remoteSigBytes).valueOr(false):
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ")")
if not remotePubKey.init(remotePubKeyBytes):
@ -550,33 +546,34 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
if not remoteSig.init(remoteSigBytes):
raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ")")
let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes
let verifyPayload = PayloadString & handshakeRes.rs.getBytes
if not remoteSig.verify(verifyPayload, remotePubKey):
raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.")
else:
trace "Remote signature verified", conn
if initiator:
let pid = PeerId.init(remotePubKey)
if not conn.peerId.validate():
raise newException(NoiseHandshakeError, "Failed to validate peerId.")
if pid.isErr or pid.get() != conn.peerId:
let pid = PeerId.init(remotePubKey).valueOr:
raise newException(NoiseHandshakeError, "Invalid remote peer id: " & $error)
trace "Remote peer id", pid = $pid
if peerId.isSome():
let targetPid = peerId.get()
if not targetPid.validate():
raise newException(NoiseHandshakeError, "Failed to validate expected peerId.")
if pid != targetPid:
var
failedKey: PublicKey
discard extractPublicKey(conn.peerId, failedKey)
debug "Noise handshake, peer infos don't match!",
discard extractPublicKey(targetPid, failedKey)
debug "Noise handshake, peer id doesn't match!",
initiator, dealt_peer = conn,
dealt_key = $failedKey, received_peer = $pid,
received_key = $remotePubKey
raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerId)
else:
let pid = PeerId.init(remotePubKey)
if pid.isErr:
raise newException(NoiseHandshakeError, "Invalid remote peer id")
conn.peerId = pid.get()
raise newException(NoiseHandshakeError, "Noise handshake, peer id don't match! " & $pid & " != " & $targetPid)
conn.peerId = pid
var tmp = NoiseConnection.new(conn, conn.peerId, conn.observedAddr)
if initiator:
tmp.readCs = handshakeRes.cs2
tmp.writeCs = handshakeRes.cs1

View File

@ -291,7 +291,7 @@ proc transactMessage(conn: Connection,
await conn.write(msg)
return await conn.readRawMessage()
method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[SecureConn] {.async.} =
method handshake*(s: Secio, conn: Connection, initiator: bool, peerId: Opt[PeerId]): Future[SecureConn] {.async.} =
var
localNonce: array[SecioNonceSize, byte]
remoteNonce: seq[byte]
@ -342,8 +342,13 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S
remotePeerId = PeerId.init(remotePubkey).tryGet()
# TODO: PeerId check against supplied PeerId
if not initiator:
if peerId.isSome():
let targetPid = peerId.get()
if not targetPid.validate():
raise newException(SecioError, "Failed to validate expected peerId.")
if remotePeerId != targetPid:
raise newException(SecioError, "Peer ids don't match!")
conn.peerId = remotePeerId
let order = getOrder(remoteBytesPubkey, localNonce, localBytesPubkey,
remoteNonce).tryGet()

View File

@ -79,13 +79,15 @@ method getWrapped*(s: SecureConn): Connection = s.stream
method handshake*(s: Secure,
conn: Connection,
initiator: bool): Future[SecureConn] {.async, base.} =
initiator: bool,
peerId: Opt[PeerId]): Future[SecureConn] {.async, base.} =
doAssert(false, "Not implemented!")
proc handleConn(s: Secure,
conn: Connection,
initiator: bool): Future[Connection] {.async.} =
var sconn = await s.handshake(conn, initiator)
initiator: bool,
peerId: Opt[PeerId]): Future[Connection] {.async.} =
var sconn = await s.handshake(conn, initiator, peerId)
# mark connection bottom level transport direction
# this is the safest place to do this
# we require this information in for example gossipsub
@ -121,7 +123,7 @@ method init*(s: Secure) =
try:
# We don't need the result but we
# definitely need to await the handshake
discard await s.handleConn(conn, false)
discard await s.handleConn(conn, false, Opt.none(PeerId))
trace "connection secured", conn
except CancelledError as exc:
warn "securing connection canceled", conn
@ -135,9 +137,10 @@ method init*(s: Secure) =
method secure*(s: Secure,
conn: Connection,
initiator: bool):
initiator: bool,
peerId: Opt[PeerId]):
Future[Connection] {.base.} =
s.handleConn(conn, initiator)
s.handleConn(conn, initiator, peerId)
method readOnce*(s: SecureConn,
pbytes: pointer,

View File

@ -128,6 +128,13 @@ method connect*(
s.dialer.connect(peerId, addrs, forceDial)
method connect*(
s: Switch,
addrs: seq[MultiAddress]): Future[PeerId] =
## Connects to a peer and retrieve its PeerId
s.dialer.connect(addrs)
method dial*(
s: Switch,
peerId: PeerId,

View File

@ -87,12 +87,13 @@ method upgradeIncoming*(
method upgradeOutgoing*(
self: Transport,
conn: Connection): Future[Connection] {.base, gcsafe.} =
conn: Connection,
peerId: Opt[PeerId]): Future[Connection] {.base, gcsafe.} =
## base upgrade method that the transport uses to perform
## transport specific upgrades
##
self.upgrader.upgradeOutgoing(conn)
self.upgrader.upgradeOutgoing(conn, peerId)
method handles*(
self: Transport,

View File

@ -88,10 +88,11 @@ proc mux*(
method upgradeOutgoing*(
self: MuxedUpgrade,
conn: Connection): Future[Connection] {.async, gcsafe.} =
conn: Connection,
peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} =
trace "Upgrading outgoing connection", conn
let sconn = await self.secure(conn) # secure the connection
let sconn = await self.secure(conn, peerId) # secure the connection
if isNil(sconn):
raise newException(UpgradeFailedError,
"unable to secure connection, stopping upgrade")
@ -129,7 +130,7 @@ method upgradeIncoming*(
var cconn = conn
try:
var sconn = await secure.secure(cconn, false)
var sconn = await secure.secure(cconn, false, Opt.none(PeerId))
if isNil(sconn):
return

View File

@ -47,12 +47,14 @@ method upgradeIncoming*(
method upgradeOutgoing*(
self: Upgrade,
conn: Connection): Future[Connection] {.base.} =
conn: Connection,
peerId: Opt[PeerId]): Future[Connection] {.base.} =
doAssert(false, "Not implemented!")
proc secure*(
self: Upgrade,
conn: Connection): Future[Connection] {.async, gcsafe.} =
conn: Connection,
peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} =
if self.secureManagers.len <= 0:
raise newException(UpgradeFailedError, "No secure managers registered!")
@ -67,7 +69,7 @@ proc secure*(
# let's avoid duplicating checks but detect if it fails to do it properly
doAssert(secureProtocol.len > 0)
return await secureProtocol[0].secure(conn, true)
return await secureProtocol[0].secure(conn, true, peerId)
proc identify*(
self: Upgrade,

View File

@ -104,7 +104,7 @@ suite "Noise":
proc acceptHandler() {.async.} =
let conn = await transport1.accept()
let sconn = await serverNoise.secure(conn, false)
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
try:
await sconn.write("Hello!")
finally:
@ -119,8 +119,7 @@ suite "Noise":
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
conn = await transport2.dial(transport1.addrs[0])
conn.peerId = serverInfo.peerId
let sconn = await clientNoise.secure(conn, true)
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
var msg = newSeq[byte](6)
await sconn.readExactly(addr msg[0], 6)
@ -149,7 +148,7 @@ suite "Noise":
var conn: Connection
try:
conn = await transport1.accept()
discard await serverNoise.secure(conn, false)
discard await serverNoise.secure(conn, false, Opt.none(PeerId))
except CatchableError:
discard
finally:
@ -162,11 +161,10 @@ suite "Noise":
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true, commonPrologue = @[1'u8, 2'u8, 3'u8])
conn = await transport2.dial(transport1.addrs[0])
conn.peerId = serverInfo.peerId
var sconn: Connection = nil
expect(NoiseDecryptTagError):
sconn = await clientNoise.secure(conn, true)
sconn = await clientNoise.secure(conn, true, Opt.some(conn.peerId))
await conn.close()
await handlerWait
@ -186,7 +184,7 @@ suite "Noise":
proc acceptHandler() {.async, gcsafe.} =
let conn = await transport1.accept()
let sconn = await serverNoise.secure(conn, false)
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
defer:
await sconn.close()
await conn.close()
@ -202,8 +200,7 @@ suite "Noise":
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
conn = await transport2.dial(transport1.addrs[0])
conn.peerId = serverInfo.peerId
let sconn = await clientNoise.secure(conn, true)
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
await sconn.write("Hello!")
await acceptFut
@ -230,7 +227,7 @@ suite "Noise":
proc acceptHandler() {.async, gcsafe.} =
let conn = await transport1.accept()
let sconn = await serverNoise.secure(conn, false)
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
defer:
await sconn.close()
let msg = await sconn.readLp(1024*1024)
@ -244,8 +241,7 @@ suite "Noise":
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
conn = await transport2.dial(transport1.addrs[0])
conn.peerId = serverInfo.peerId
let sconn = await clientNoise.secure(conn, true)
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
await sconn.writeLp(hugePayload)
await readTask

View File

@ -201,6 +201,20 @@ suite "Switch":
check not switch1.isConnected(switch2.peerInfo.peerId)
check not switch2.isConnected(switch1.peerInfo.peerId)
asyncTest "e2e connect to peer with unkown PeerId":
let switch1 = newStandardSwitch(secureManagers = [SecureProtocol.Noise])
let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Noise])
await switch1.start()
await switch2.start()
check: (await switch2.connect(switch1.peerInfo.addrs)) == switch1.peerInfo.peerId
await switch2.disconnect(switch1.peerInfo.peerId)
await allFuturesThrowing(
switch1.stop(),
switch2.stop()
)
asyncTest "e2e should not leak on peer disconnect":
let switch1 = newStandardSwitch()
let switch2 = newStandardSwitch()