allow connection to a peer with unknown PeerId (#756)
Co-authored-by: Tanguy <tanguy@status.im>
This commit is contained in:
parent
1de7508b64
commit
dfbfbe6eb6
|
@ -31,7 +31,6 @@ const
|
|||
type
|
||||
Curve25519* = object
|
||||
Curve25519Key* = array[Curve25519KeySize, byte]
|
||||
pcuchar = ptr char
|
||||
Curve25519Error* = enum
|
||||
Curver25519GenError
|
||||
|
||||
|
|
|
@ -31,6 +31,13 @@ method connect*(
|
|||
|
||||
doAssert(false, "Not implemented!")
|
||||
|
||||
method connect*(
|
||||
self: Dial,
|
||||
addrs: seq[MultiAddress]): Future[PeerId] {.async, base.} =
|
||||
## Connects to a peer and retrieve its PeerId
|
||||
|
||||
doAssert(false, "Not implemented!")
|
||||
|
||||
method dial*(
|
||||
self: Dial,
|
||||
peerId: PeerId,
|
||||
|
|
|
@ -47,7 +47,7 @@ type
|
|||
|
||||
proc dialAndUpgrade(
|
||||
self: Dialer,
|
||||
peerId: PeerId,
|
||||
peerId: Opt[PeerId],
|
||||
addrs: seq[MultiAddress]):
|
||||
Future[Connection] {.async.} =
|
||||
debug "Dialing peer", peerId
|
||||
|
@ -74,9 +74,6 @@ proc dialAndUpgrade(
|
|||
libp2p_failed_dials.inc()
|
||||
continue # Try the next address
|
||||
|
||||
# make sure to assign the peer to the connection
|
||||
dialed.peerId = peerId
|
||||
|
||||
# also keep track of the connection's bottom unsafe transport direction
|
||||
# required by gossipsub scoring
|
||||
dialed.transportDir = Direction.Out
|
||||
|
@ -84,7 +81,7 @@ proc dialAndUpgrade(
|
|||
libp2p_successful_dials.inc()
|
||||
|
||||
let conn = try:
|
||||
await transport.upgradeOutgoing(dialed)
|
||||
await transport.upgradeOutgoing(dialed, peerId)
|
||||
except CatchableError as exc:
|
||||
# If we failed to establish the connection through one transport,
|
||||
# we won't succeeded through another - no use in trying again
|
||||
|
@ -101,20 +98,22 @@ proc dialAndUpgrade(
|
|||
|
||||
proc internalConnect(
|
||||
self: Dialer,
|
||||
peerId: PeerId,
|
||||
peerId: Opt[PeerId],
|
||||
addrs: seq[MultiAddress],
|
||||
forceDial: bool):
|
||||
Future[Connection] {.async.} =
|
||||
if self.localPeerId == peerId:
|
||||
if Opt.some(self.localPeerId) == peerId:
|
||||
raise newException(CatchableError, "can't dial self!")
|
||||
|
||||
# Ensure there's only one in-flight attempt per peer
|
||||
let lock = self.dialLock.mgetOrPut(peerId, newAsyncLock())
|
||||
let lock = self.dialLock.mgetOrPut(peerId.get(default(PeerId)), newAsyncLock())
|
||||
try:
|
||||
await lock.acquire()
|
||||
|
||||
# Check if we have a connection already and try to reuse it
|
||||
var conn = self.connManager.selectConn(peerId)
|
||||
var conn =
|
||||
if peerId.isSome: self.connManager.selectConn(peerId.get())
|
||||
else: nil
|
||||
if conn != nil:
|
||||
if conn.atEof or conn.closed:
|
||||
# This connection should already have been removed from the connection
|
||||
|
@ -165,7 +164,15 @@ method connect*(
|
|||
if self.connManager.connCount(peerId) > 0:
|
||||
return
|
||||
|
||||
discard await self.internalConnect(peerId, addrs, forceDial)
|
||||
discard await self.internalConnect(Opt.some(peerId), addrs, forceDial)
|
||||
|
||||
method connect*(
|
||||
self: Dialer,
|
||||
addrs: seq[MultiAddress],
|
||||
): Future[PeerId] {.async.} =
|
||||
## Connects to a peer and retrieve its PeerId
|
||||
|
||||
return (await self.internalConnect(Opt.none(PeerId), addrs, false)).peerId
|
||||
|
||||
proc negotiateStream(
|
||||
self: Dialer,
|
||||
|
@ -190,7 +197,7 @@ method tryDial*(
|
|||
|
||||
trace "Check if it can dial", peerId, addrs
|
||||
try:
|
||||
let conn = await self.dialAndUpgrade(peerId, addrs)
|
||||
let conn = await self.dialAndUpgrade(Opt.some(peerId), addrs)
|
||||
if conn.isNil():
|
||||
raise newException(DialFailedError, "No valid multiaddress")
|
||||
await conn.close()
|
||||
|
@ -238,7 +245,7 @@ method dial*(
|
|||
|
||||
try:
|
||||
trace "Dialing (new)", peerId, protos
|
||||
conn = await self.internalConnect(peerId, addrs, forceDial)
|
||||
conn = await self.internalConnect(Opt.some(peerId), addrs, forceDial)
|
||||
trace "Opening stream", conn
|
||||
stream = await self.connManager.getStream(conn)
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ const
|
|||
# https://godoc.org/github.com/libp2p/go-libp2p-noise#pkg-constants
|
||||
NoiseCodec* = "/noise"
|
||||
|
||||
PayloadString = "noise-libp2p-static-key:"
|
||||
PayloadString = toBytes("noise-libp2p-static-key:")
|
||||
|
||||
ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256"
|
||||
|
||||
|
@ -339,7 +339,6 @@ proc handshakeXXOutbound(
|
|||
hs = HandshakeState.init()
|
||||
|
||||
try:
|
||||
|
||||
hs.ss.mixHash(p.commonPrologue)
|
||||
hs.s = p.noiseKeys
|
||||
|
||||
|
@ -445,7 +444,6 @@ method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
|
|||
dumpMessage(sconn, FlowDirection.Incoming, [])
|
||||
trace "Received 0-length message", sconn
|
||||
|
||||
|
||||
proc encryptFrame(
|
||||
sconn: NoiseConnection,
|
||||
cipherFrame: var openArray[byte],
|
||||
|
@ -506,7 +504,7 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] =
|
|||
# sequencing issues
|
||||
sconn.stream.write(cipherFrames)
|
||||
|
||||
method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} =
|
||||
method handshake*(p: Noise, conn: Connection, initiator: bool, peerId: Opt[PeerId]): Future[SecureConn] {.async.} =
|
||||
trace "Starting Noise handshake", conn, initiator
|
||||
|
||||
let timeout = conn.timeout
|
||||
|
@ -515,7 +513,7 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
|
|||
# https://github.com/libp2p/specs/tree/master/noise#libp2p-data-in-handshake-messages
|
||||
let
|
||||
signedPayload = p.localPrivateKey.sign(
|
||||
PayloadString.toBytes & p.noiseKeys.publicKey.getBytes).tryGet()
|
||||
PayloadString & p.noiseKeys.publicKey.getBytes).tryGet()
|
||||
|
||||
var
|
||||
libp2pProof = initProtoBuffer()
|
||||
|
@ -538,11 +536,9 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
|
|||
remoteSig: Signature
|
||||
remoteSigBytes: seq[byte]
|
||||
|
||||
let r1 = remoteProof.getField(1, remotePubKeyBytes)
|
||||
let r2 = remoteProof.getField(2, remoteSigBytes)
|
||||
if r1.isErr() or not(r1.get()):
|
||||
if not remoteProof.getField(1, remotePubKeyBytes).valueOr(false):
|
||||
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ")")
|
||||
if r2.isErr() or not(r2.get()):
|
||||
if not remoteProof.getField(2, remoteSigBytes).valueOr(false):
|
||||
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ")")
|
||||
|
||||
if not remotePubKey.init(remotePubKeyBytes):
|
||||
|
@ -550,33 +546,34 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
|
|||
if not remoteSig.init(remoteSigBytes):
|
||||
raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ")")
|
||||
|
||||
let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes
|
||||
let verifyPayload = PayloadString & handshakeRes.rs.getBytes
|
||||
if not remoteSig.verify(verifyPayload, remotePubKey):
|
||||
raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.")
|
||||
else:
|
||||
trace "Remote signature verified", conn
|
||||
|
||||
if initiator:
|
||||
let pid = PeerId.init(remotePubKey)
|
||||
if not conn.peerId.validate():
|
||||
raise newException(NoiseHandshakeError, "Failed to validate peerId.")
|
||||
if pid.isErr or pid.get() != conn.peerId:
|
||||
let pid = PeerId.init(remotePubKey).valueOr:
|
||||
raise newException(NoiseHandshakeError, "Invalid remote peer id: " & $error)
|
||||
|
||||
trace "Remote peer id", pid = $pid
|
||||
|
||||
if peerId.isSome():
|
||||
let targetPid = peerId.get()
|
||||
if not targetPid.validate():
|
||||
raise newException(NoiseHandshakeError, "Failed to validate expected peerId.")
|
||||
|
||||
if pid != targetPid:
|
||||
var
|
||||
failedKey: PublicKey
|
||||
discard extractPublicKey(conn.peerId, failedKey)
|
||||
debug "Noise handshake, peer infos don't match!",
|
||||
discard extractPublicKey(targetPid, failedKey)
|
||||
debug "Noise handshake, peer id doesn't match!",
|
||||
initiator, dealt_peer = conn,
|
||||
dealt_key = $failedKey, received_peer = $pid,
|
||||
received_key = $remotePubKey
|
||||
raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerId)
|
||||
else:
|
||||
let pid = PeerId.init(remotePubKey)
|
||||
if pid.isErr:
|
||||
raise newException(NoiseHandshakeError, "Invalid remote peer id")
|
||||
conn.peerId = pid.get()
|
||||
raise newException(NoiseHandshakeError, "Noise handshake, peer id don't match! " & $pid & " != " & $targetPid)
|
||||
conn.peerId = pid
|
||||
|
||||
var tmp = NoiseConnection.new(conn, conn.peerId, conn.observedAddr)
|
||||
|
||||
if initiator:
|
||||
tmp.readCs = handshakeRes.cs2
|
||||
tmp.writeCs = handshakeRes.cs1
|
||||
|
|
|
@ -291,7 +291,7 @@ proc transactMessage(conn: Connection,
|
|||
await conn.write(msg)
|
||||
return await conn.readRawMessage()
|
||||
|
||||
method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[SecureConn] {.async.} =
|
||||
method handshake*(s: Secio, conn: Connection, initiator: bool, peerId: Opt[PeerId]): Future[SecureConn] {.async.} =
|
||||
var
|
||||
localNonce: array[SecioNonceSize, byte]
|
||||
remoteNonce: seq[byte]
|
||||
|
@ -342,8 +342,13 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S
|
|||
|
||||
remotePeerId = PeerId.init(remotePubkey).tryGet()
|
||||
|
||||
# TODO: PeerId check against supplied PeerId
|
||||
if not initiator:
|
||||
if peerId.isSome():
|
||||
let targetPid = peerId.get()
|
||||
if not targetPid.validate():
|
||||
raise newException(SecioError, "Failed to validate expected peerId.")
|
||||
|
||||
if remotePeerId != targetPid:
|
||||
raise newException(SecioError, "Peer ids don't match!")
|
||||
conn.peerId = remotePeerId
|
||||
let order = getOrder(remoteBytesPubkey, localNonce, localBytesPubkey,
|
||||
remoteNonce).tryGet()
|
||||
|
|
|
@ -79,13 +79,15 @@ method getWrapped*(s: SecureConn): Connection = s.stream
|
|||
|
||||
method handshake*(s: Secure,
|
||||
conn: Connection,
|
||||
initiator: bool): Future[SecureConn] {.async, base.} =
|
||||
initiator: bool,
|
||||
peerId: Opt[PeerId]): Future[SecureConn] {.async, base.} =
|
||||
doAssert(false, "Not implemented!")
|
||||
|
||||
proc handleConn(s: Secure,
|
||||
conn: Connection,
|
||||
initiator: bool): Future[Connection] {.async.} =
|
||||
var sconn = await s.handshake(conn, initiator)
|
||||
initiator: bool,
|
||||
peerId: Opt[PeerId]): Future[Connection] {.async.} =
|
||||
var sconn = await s.handshake(conn, initiator, peerId)
|
||||
# mark connection bottom level transport direction
|
||||
# this is the safest place to do this
|
||||
# we require this information in for example gossipsub
|
||||
|
@ -121,7 +123,7 @@ method init*(s: Secure) =
|
|||
try:
|
||||
# We don't need the result but we
|
||||
# definitely need to await the handshake
|
||||
discard await s.handleConn(conn, false)
|
||||
discard await s.handleConn(conn, false, Opt.none(PeerId))
|
||||
trace "connection secured", conn
|
||||
except CancelledError as exc:
|
||||
warn "securing connection canceled", conn
|
||||
|
@ -135,9 +137,10 @@ method init*(s: Secure) =
|
|||
|
||||
method secure*(s: Secure,
|
||||
conn: Connection,
|
||||
initiator: bool):
|
||||
initiator: bool,
|
||||
peerId: Opt[PeerId]):
|
||||
Future[Connection] {.base.} =
|
||||
s.handleConn(conn, initiator)
|
||||
s.handleConn(conn, initiator, peerId)
|
||||
|
||||
method readOnce*(s: SecureConn,
|
||||
pbytes: pointer,
|
||||
|
|
|
@ -128,6 +128,13 @@ method connect*(
|
|||
|
||||
s.dialer.connect(peerId, addrs, forceDial)
|
||||
|
||||
method connect*(
|
||||
s: Switch,
|
||||
addrs: seq[MultiAddress]): Future[PeerId] =
|
||||
## Connects to a peer and retrieve its PeerId
|
||||
|
||||
s.dialer.connect(addrs)
|
||||
|
||||
method dial*(
|
||||
s: Switch,
|
||||
peerId: PeerId,
|
||||
|
|
|
@ -87,12 +87,13 @@ method upgradeIncoming*(
|
|||
|
||||
method upgradeOutgoing*(
|
||||
self: Transport,
|
||||
conn: Connection): Future[Connection] {.base, gcsafe.} =
|
||||
conn: Connection,
|
||||
peerId: Opt[PeerId]): Future[Connection] {.base, gcsafe.} =
|
||||
## base upgrade method that the transport uses to perform
|
||||
## transport specific upgrades
|
||||
##
|
||||
|
||||
self.upgrader.upgradeOutgoing(conn)
|
||||
self.upgrader.upgradeOutgoing(conn, peerId)
|
||||
|
||||
method handles*(
|
||||
self: Transport,
|
||||
|
|
|
@ -88,10 +88,11 @@ proc mux*(
|
|||
|
||||
method upgradeOutgoing*(
|
||||
self: MuxedUpgrade,
|
||||
conn: Connection): Future[Connection] {.async, gcsafe.} =
|
||||
conn: Connection,
|
||||
peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} =
|
||||
trace "Upgrading outgoing connection", conn
|
||||
|
||||
let sconn = await self.secure(conn) # secure the connection
|
||||
let sconn = await self.secure(conn, peerId) # secure the connection
|
||||
if isNil(sconn):
|
||||
raise newException(UpgradeFailedError,
|
||||
"unable to secure connection, stopping upgrade")
|
||||
|
@ -129,7 +130,7 @@ method upgradeIncoming*(
|
|||
|
||||
var cconn = conn
|
||||
try:
|
||||
var sconn = await secure.secure(cconn, false)
|
||||
var sconn = await secure.secure(cconn, false, Opt.none(PeerId))
|
||||
if isNil(sconn):
|
||||
return
|
||||
|
||||
|
|
|
@ -47,12 +47,14 @@ method upgradeIncoming*(
|
|||
|
||||
method upgradeOutgoing*(
|
||||
self: Upgrade,
|
||||
conn: Connection): Future[Connection] {.base.} =
|
||||
conn: Connection,
|
||||
peerId: Opt[PeerId]): Future[Connection] {.base.} =
|
||||
doAssert(false, "Not implemented!")
|
||||
|
||||
proc secure*(
|
||||
self: Upgrade,
|
||||
conn: Connection): Future[Connection] {.async, gcsafe.} =
|
||||
conn: Connection,
|
||||
peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} =
|
||||
if self.secureManagers.len <= 0:
|
||||
raise newException(UpgradeFailedError, "No secure managers registered!")
|
||||
|
||||
|
@ -67,7 +69,7 @@ proc secure*(
|
|||
# let's avoid duplicating checks but detect if it fails to do it properly
|
||||
doAssert(secureProtocol.len > 0)
|
||||
|
||||
return await secureProtocol[0].secure(conn, true)
|
||||
return await secureProtocol[0].secure(conn, true, peerId)
|
||||
|
||||
proc identify*(
|
||||
self: Upgrade,
|
||||
|
|
|
@ -104,7 +104,7 @@ suite "Noise":
|
|||
|
||||
proc acceptHandler() {.async.} =
|
||||
let conn = await transport1.accept()
|
||||
let sconn = await serverNoise.secure(conn, false)
|
||||
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
|
||||
try:
|
||||
await sconn.write("Hello!")
|
||||
finally:
|
||||
|
@ -119,8 +119,7 @@ suite "Noise":
|
|||
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
|
||||
conn = await transport2.dial(transport1.addrs[0])
|
||||
|
||||
conn.peerId = serverInfo.peerId
|
||||
let sconn = await clientNoise.secure(conn, true)
|
||||
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
|
||||
|
||||
var msg = newSeq[byte](6)
|
||||
await sconn.readExactly(addr msg[0], 6)
|
||||
|
@ -149,7 +148,7 @@ suite "Noise":
|
|||
var conn: Connection
|
||||
try:
|
||||
conn = await transport1.accept()
|
||||
discard await serverNoise.secure(conn, false)
|
||||
discard await serverNoise.secure(conn, false, Opt.none(PeerId))
|
||||
except CatchableError:
|
||||
discard
|
||||
finally:
|
||||
|
@ -162,11 +161,10 @@ suite "Noise":
|
|||
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
|
||||
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true, commonPrologue = @[1'u8, 2'u8, 3'u8])
|
||||
conn = await transport2.dial(transport1.addrs[0])
|
||||
conn.peerId = serverInfo.peerId
|
||||
|
||||
var sconn: Connection = nil
|
||||
expect(NoiseDecryptTagError):
|
||||
sconn = await clientNoise.secure(conn, true)
|
||||
sconn = await clientNoise.secure(conn, true, Opt.some(conn.peerId))
|
||||
|
||||
await conn.close()
|
||||
await handlerWait
|
||||
|
@ -186,7 +184,7 @@ suite "Noise":
|
|||
|
||||
proc acceptHandler() {.async, gcsafe.} =
|
||||
let conn = await transport1.accept()
|
||||
let sconn = await serverNoise.secure(conn, false)
|
||||
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
|
||||
defer:
|
||||
await sconn.close()
|
||||
await conn.close()
|
||||
|
@ -202,8 +200,7 @@ suite "Noise":
|
|||
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
|
||||
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
|
||||
conn = await transport2.dial(transport1.addrs[0])
|
||||
conn.peerId = serverInfo.peerId
|
||||
let sconn = await clientNoise.secure(conn, true)
|
||||
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
|
||||
|
||||
await sconn.write("Hello!")
|
||||
await acceptFut
|
||||
|
@ -230,7 +227,7 @@ suite "Noise":
|
|||
|
||||
proc acceptHandler() {.async, gcsafe.} =
|
||||
let conn = await transport1.accept()
|
||||
let sconn = await serverNoise.secure(conn, false)
|
||||
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
|
||||
defer:
|
||||
await sconn.close()
|
||||
let msg = await sconn.readLp(1024*1024)
|
||||
|
@ -244,8 +241,7 @@ suite "Noise":
|
|||
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
|
||||
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
|
||||
conn = await transport2.dial(transport1.addrs[0])
|
||||
conn.peerId = serverInfo.peerId
|
||||
let sconn = await clientNoise.secure(conn, true)
|
||||
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
|
||||
|
||||
await sconn.writeLp(hugePayload)
|
||||
await readTask
|
||||
|
|
|
@ -201,6 +201,20 @@ suite "Switch":
|
|||
check not switch1.isConnected(switch2.peerInfo.peerId)
|
||||
check not switch2.isConnected(switch1.peerInfo.peerId)
|
||||
|
||||
asyncTest "e2e connect to peer with unkown PeerId":
|
||||
let switch1 = newStandardSwitch(secureManagers = [SecureProtocol.Noise])
|
||||
let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Noise])
|
||||
await switch1.start()
|
||||
await switch2.start()
|
||||
|
||||
check: (await switch2.connect(switch1.peerInfo.addrs)) == switch1.peerInfo.peerId
|
||||
await switch2.disconnect(switch1.peerInfo.peerId)
|
||||
|
||||
await allFuturesThrowing(
|
||||
switch1.stop(),
|
||||
switch2.stop()
|
||||
)
|
||||
|
||||
asyncTest "e2e should not leak on peer disconnect":
|
||||
let switch1 = newStandardSwitch()
|
||||
let switch2 = newStandardSwitch()
|
||||
|
|
Loading…
Reference in New Issue