mirror of https://github.com/vacp2p/nim-libp2p.git
allow connection to a peer with unknown PeerId (#756)
Co-authored-by: Tanguy <tanguy@status.im>
This commit is contained in:
parent
1de7508b64
commit
dfbfbe6eb6
|
@ -31,7 +31,6 @@ const
|
||||||
type
|
type
|
||||||
Curve25519* = object
|
Curve25519* = object
|
||||||
Curve25519Key* = array[Curve25519KeySize, byte]
|
Curve25519Key* = array[Curve25519KeySize, byte]
|
||||||
pcuchar = ptr char
|
|
||||||
Curve25519Error* = enum
|
Curve25519Error* = enum
|
||||||
Curver25519GenError
|
Curver25519GenError
|
||||||
|
|
||||||
|
@ -77,7 +76,7 @@ proc mulgen(_: type[Curve25519], dst: var Curve25519Key, point: Curve25519Key) =
|
||||||
addr rpoint[0],
|
addr rpoint[0],
|
||||||
Curve25519KeySize,
|
Curve25519KeySize,
|
||||||
EC_curve25519)
|
EC_curve25519)
|
||||||
|
|
||||||
assert size == Curve25519KeySize
|
assert size == Curve25519KeySize
|
||||||
|
|
||||||
proc public*(private: Curve25519Key): Curve25519Key =
|
proc public*(private: Curve25519Key): Curve25519Key =
|
||||||
|
|
|
@ -31,6 +31,13 @@ method connect*(
|
||||||
|
|
||||||
doAssert(false, "Not implemented!")
|
doAssert(false, "Not implemented!")
|
||||||
|
|
||||||
|
method connect*(
|
||||||
|
self: Dial,
|
||||||
|
addrs: seq[MultiAddress]): Future[PeerId] {.async, base.} =
|
||||||
|
## Connects to a peer and retrieve its PeerId
|
||||||
|
|
||||||
|
doAssert(false, "Not implemented!")
|
||||||
|
|
||||||
method dial*(
|
method dial*(
|
||||||
self: Dial,
|
self: Dial,
|
||||||
peerId: PeerId,
|
peerId: PeerId,
|
||||||
|
|
|
@ -47,7 +47,7 @@ type
|
||||||
|
|
||||||
proc dialAndUpgrade(
|
proc dialAndUpgrade(
|
||||||
self: Dialer,
|
self: Dialer,
|
||||||
peerId: PeerId,
|
peerId: Opt[PeerId],
|
||||||
addrs: seq[MultiAddress]):
|
addrs: seq[MultiAddress]):
|
||||||
Future[Connection] {.async.} =
|
Future[Connection] {.async.} =
|
||||||
debug "Dialing peer", peerId
|
debug "Dialing peer", peerId
|
||||||
|
@ -74,9 +74,6 @@ proc dialAndUpgrade(
|
||||||
libp2p_failed_dials.inc()
|
libp2p_failed_dials.inc()
|
||||||
continue # Try the next address
|
continue # Try the next address
|
||||||
|
|
||||||
# make sure to assign the peer to the connection
|
|
||||||
dialed.peerId = peerId
|
|
||||||
|
|
||||||
# also keep track of the connection's bottom unsafe transport direction
|
# also keep track of the connection's bottom unsafe transport direction
|
||||||
# required by gossipsub scoring
|
# required by gossipsub scoring
|
||||||
dialed.transportDir = Direction.Out
|
dialed.transportDir = Direction.Out
|
||||||
|
@ -84,7 +81,7 @@ proc dialAndUpgrade(
|
||||||
libp2p_successful_dials.inc()
|
libp2p_successful_dials.inc()
|
||||||
|
|
||||||
let conn = try:
|
let conn = try:
|
||||||
await transport.upgradeOutgoing(dialed)
|
await transport.upgradeOutgoing(dialed, peerId)
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
# If we failed to establish the connection through one transport,
|
# If we failed to establish the connection through one transport,
|
||||||
# we won't succeeded through another - no use in trying again
|
# we won't succeeded through another - no use in trying again
|
||||||
|
@ -101,20 +98,22 @@ proc dialAndUpgrade(
|
||||||
|
|
||||||
proc internalConnect(
|
proc internalConnect(
|
||||||
self: Dialer,
|
self: Dialer,
|
||||||
peerId: PeerId,
|
peerId: Opt[PeerId],
|
||||||
addrs: seq[MultiAddress],
|
addrs: seq[MultiAddress],
|
||||||
forceDial: bool):
|
forceDial: bool):
|
||||||
Future[Connection] {.async.} =
|
Future[Connection] {.async.} =
|
||||||
if self.localPeerId == peerId:
|
if Opt.some(self.localPeerId) == peerId:
|
||||||
raise newException(CatchableError, "can't dial self!")
|
raise newException(CatchableError, "can't dial self!")
|
||||||
|
|
||||||
# Ensure there's only one in-flight attempt per peer
|
# Ensure there's only one in-flight attempt per peer
|
||||||
let lock = self.dialLock.mgetOrPut(peerId, newAsyncLock())
|
let lock = self.dialLock.mgetOrPut(peerId.get(default(PeerId)), newAsyncLock())
|
||||||
try:
|
try:
|
||||||
await lock.acquire()
|
await lock.acquire()
|
||||||
|
|
||||||
# Check if we have a connection already and try to reuse it
|
# Check if we have a connection already and try to reuse it
|
||||||
var conn = self.connManager.selectConn(peerId)
|
var conn =
|
||||||
|
if peerId.isSome: self.connManager.selectConn(peerId.get())
|
||||||
|
else: nil
|
||||||
if conn != nil:
|
if conn != nil:
|
||||||
if conn.atEof or conn.closed:
|
if conn.atEof or conn.closed:
|
||||||
# This connection should already have been removed from the connection
|
# This connection should already have been removed from the connection
|
||||||
|
@ -165,7 +164,15 @@ method connect*(
|
||||||
if self.connManager.connCount(peerId) > 0:
|
if self.connManager.connCount(peerId) > 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
discard await self.internalConnect(peerId, addrs, forceDial)
|
discard await self.internalConnect(Opt.some(peerId), addrs, forceDial)
|
||||||
|
|
||||||
|
method connect*(
|
||||||
|
self: Dialer,
|
||||||
|
addrs: seq[MultiAddress],
|
||||||
|
): Future[PeerId] {.async.} =
|
||||||
|
## Connects to a peer and retrieve its PeerId
|
||||||
|
|
||||||
|
return (await self.internalConnect(Opt.none(PeerId), addrs, false)).peerId
|
||||||
|
|
||||||
proc negotiateStream(
|
proc negotiateStream(
|
||||||
self: Dialer,
|
self: Dialer,
|
||||||
|
@ -190,7 +197,7 @@ method tryDial*(
|
||||||
|
|
||||||
trace "Check if it can dial", peerId, addrs
|
trace "Check if it can dial", peerId, addrs
|
||||||
try:
|
try:
|
||||||
let conn = await self.dialAndUpgrade(peerId, addrs)
|
let conn = await self.dialAndUpgrade(Opt.some(peerId), addrs)
|
||||||
if conn.isNil():
|
if conn.isNil():
|
||||||
raise newException(DialFailedError, "No valid multiaddress")
|
raise newException(DialFailedError, "No valid multiaddress")
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
@ -238,7 +245,7 @@ method dial*(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
trace "Dialing (new)", peerId, protos
|
trace "Dialing (new)", peerId, protos
|
||||||
conn = await self.internalConnect(peerId, addrs, forceDial)
|
conn = await self.internalConnect(Opt.some(peerId), addrs, forceDial)
|
||||||
trace "Opening stream", conn
|
trace "Opening stream", conn
|
||||||
stream = await self.connManager.getStream(conn)
|
stream = await self.connManager.getStream(conn)
|
||||||
|
|
||||||
|
|
|
@ -43,7 +43,7 @@ func shortLog*(pid: PeerId): string =
|
||||||
var spid = $pid
|
var spid = $pid
|
||||||
if len(spid) > 10:
|
if len(spid) > 10:
|
||||||
spid[3] = '*'
|
spid[3] = '*'
|
||||||
|
|
||||||
when (NimMajor, NimMinor) > (1, 4):
|
when (NimMajor, NimMinor) > (1, 4):
|
||||||
spid.delete(4 .. spid.high - 6)
|
spid.delete(4 .. spid.high - 6)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -38,7 +38,7 @@ const
|
||||||
# https://godoc.org/github.com/libp2p/go-libp2p-noise#pkg-constants
|
# https://godoc.org/github.com/libp2p/go-libp2p-noise#pkg-constants
|
||||||
NoiseCodec* = "/noise"
|
NoiseCodec* = "/noise"
|
||||||
|
|
||||||
PayloadString = "noise-libp2p-static-key:"
|
PayloadString = toBytes("noise-libp2p-static-key:")
|
||||||
|
|
||||||
ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256"
|
ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256"
|
||||||
|
|
||||||
|
@ -339,7 +339,6 @@ proc handshakeXXOutbound(
|
||||||
hs = HandshakeState.init()
|
hs = HandshakeState.init()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
hs.ss.mixHash(p.commonPrologue)
|
hs.ss.mixHash(p.commonPrologue)
|
||||||
hs.s = p.noiseKeys
|
hs.s = p.noiseKeys
|
||||||
|
|
||||||
|
@ -445,7 +444,6 @@ method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
|
||||||
dumpMessage(sconn, FlowDirection.Incoming, [])
|
dumpMessage(sconn, FlowDirection.Incoming, [])
|
||||||
trace "Received 0-length message", sconn
|
trace "Received 0-length message", sconn
|
||||||
|
|
||||||
|
|
||||||
proc encryptFrame(
|
proc encryptFrame(
|
||||||
sconn: NoiseConnection,
|
sconn: NoiseConnection,
|
||||||
cipherFrame: var openArray[byte],
|
cipherFrame: var openArray[byte],
|
||||||
|
@ -506,7 +504,7 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] =
|
||||||
# sequencing issues
|
# sequencing issues
|
||||||
sconn.stream.write(cipherFrames)
|
sconn.stream.write(cipherFrames)
|
||||||
|
|
||||||
method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureConn] {.async.} =
|
method handshake*(p: Noise, conn: Connection, initiator: bool, peerId: Opt[PeerId]): Future[SecureConn] {.async.} =
|
||||||
trace "Starting Noise handshake", conn, initiator
|
trace "Starting Noise handshake", conn, initiator
|
||||||
|
|
||||||
let timeout = conn.timeout
|
let timeout = conn.timeout
|
||||||
|
@ -515,7 +513,7 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
|
||||||
# https://github.com/libp2p/specs/tree/master/noise#libp2p-data-in-handshake-messages
|
# https://github.com/libp2p/specs/tree/master/noise#libp2p-data-in-handshake-messages
|
||||||
let
|
let
|
||||||
signedPayload = p.localPrivateKey.sign(
|
signedPayload = p.localPrivateKey.sign(
|
||||||
PayloadString.toBytes & p.noiseKeys.publicKey.getBytes).tryGet()
|
PayloadString & p.noiseKeys.publicKey.getBytes).tryGet()
|
||||||
|
|
||||||
var
|
var
|
||||||
libp2pProof = initProtoBuffer()
|
libp2pProof = initProtoBuffer()
|
||||||
|
@ -538,11 +536,9 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
|
||||||
remoteSig: Signature
|
remoteSig: Signature
|
||||||
remoteSigBytes: seq[byte]
|
remoteSigBytes: seq[byte]
|
||||||
|
|
||||||
let r1 = remoteProof.getField(1, remotePubKeyBytes)
|
if not remoteProof.getField(1, remotePubKeyBytes).valueOr(false):
|
||||||
let r2 = remoteProof.getField(2, remoteSigBytes)
|
|
||||||
if r1.isErr() or not(r1.get()):
|
|
||||||
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ")")
|
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ")")
|
||||||
if r2.isErr() or not(r2.get()):
|
if not remoteProof.getField(2, remoteSigBytes).valueOr(false):
|
||||||
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ")")
|
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ")")
|
||||||
|
|
||||||
if not remotePubKey.init(remotePubKeyBytes):
|
if not remotePubKey.init(remotePubKeyBytes):
|
||||||
|
@ -550,33 +546,34 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
|
||||||
if not remoteSig.init(remoteSigBytes):
|
if not remoteSig.init(remoteSigBytes):
|
||||||
raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ")")
|
raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ")")
|
||||||
|
|
||||||
let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes
|
let verifyPayload = PayloadString & handshakeRes.rs.getBytes
|
||||||
if not remoteSig.verify(verifyPayload, remotePubKey):
|
if not remoteSig.verify(verifyPayload, remotePubKey):
|
||||||
raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.")
|
raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.")
|
||||||
else:
|
else:
|
||||||
trace "Remote signature verified", conn
|
trace "Remote signature verified", conn
|
||||||
|
|
||||||
if initiator:
|
let pid = PeerId.init(remotePubKey).valueOr:
|
||||||
let pid = PeerId.init(remotePubKey)
|
raise newException(NoiseHandshakeError, "Invalid remote peer id: " & $error)
|
||||||
if not conn.peerId.validate():
|
|
||||||
raise newException(NoiseHandshakeError, "Failed to validate peerId.")
|
trace "Remote peer id", pid = $pid
|
||||||
if pid.isErr or pid.get() != conn.peerId:
|
|
||||||
|
if peerId.isSome():
|
||||||
|
let targetPid = peerId.get()
|
||||||
|
if not targetPid.validate():
|
||||||
|
raise newException(NoiseHandshakeError, "Failed to validate expected peerId.")
|
||||||
|
|
||||||
|
if pid != targetPid:
|
||||||
var
|
var
|
||||||
failedKey: PublicKey
|
failedKey: PublicKey
|
||||||
discard extractPublicKey(conn.peerId, failedKey)
|
discard extractPublicKey(targetPid, failedKey)
|
||||||
debug "Noise handshake, peer infos don't match!",
|
debug "Noise handshake, peer id doesn't match!",
|
||||||
initiator, dealt_peer = conn,
|
initiator, dealt_peer = conn,
|
||||||
dealt_key = $failedKey, received_peer = $pid,
|
dealt_key = $failedKey, received_peer = $pid,
|
||||||
received_key = $remotePubKey
|
received_key = $remotePubKey
|
||||||
raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerId)
|
raise newException(NoiseHandshakeError, "Noise handshake, peer id don't match! " & $pid & " != " & $targetPid)
|
||||||
else:
|
conn.peerId = pid
|
||||||
let pid = PeerId.init(remotePubKey)
|
|
||||||
if pid.isErr:
|
|
||||||
raise newException(NoiseHandshakeError, "Invalid remote peer id")
|
|
||||||
conn.peerId = pid.get()
|
|
||||||
|
|
||||||
var tmp = NoiseConnection.new(conn, conn.peerId, conn.observedAddr)
|
var tmp = NoiseConnection.new(conn, conn.peerId, conn.observedAddr)
|
||||||
|
|
||||||
if initiator:
|
if initiator:
|
||||||
tmp.readCs = handshakeRes.cs2
|
tmp.readCs = handshakeRes.cs2
|
||||||
tmp.writeCs = handshakeRes.cs1
|
tmp.writeCs = handshakeRes.cs1
|
||||||
|
|
|
@ -291,7 +291,7 @@ proc transactMessage(conn: Connection,
|
||||||
await conn.write(msg)
|
await conn.write(msg)
|
||||||
return await conn.readRawMessage()
|
return await conn.readRawMessage()
|
||||||
|
|
||||||
method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[SecureConn] {.async.} =
|
method handshake*(s: Secio, conn: Connection, initiator: bool, peerId: Opt[PeerId]): Future[SecureConn] {.async.} =
|
||||||
var
|
var
|
||||||
localNonce: array[SecioNonceSize, byte]
|
localNonce: array[SecioNonceSize, byte]
|
||||||
remoteNonce: seq[byte]
|
remoteNonce: seq[byte]
|
||||||
|
@ -342,9 +342,14 @@ method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[S
|
||||||
|
|
||||||
remotePeerId = PeerId.init(remotePubkey).tryGet()
|
remotePeerId = PeerId.init(remotePubkey).tryGet()
|
||||||
|
|
||||||
# TODO: PeerId check against supplied PeerId
|
if peerId.isSome():
|
||||||
if not initiator:
|
let targetPid = peerId.get()
|
||||||
conn.peerId = remotePeerId
|
if not targetPid.validate():
|
||||||
|
raise newException(SecioError, "Failed to validate expected peerId.")
|
||||||
|
|
||||||
|
if remotePeerId != targetPid:
|
||||||
|
raise newException(SecioError, "Peer ids don't match!")
|
||||||
|
conn.peerId = remotePeerId
|
||||||
let order = getOrder(remoteBytesPubkey, localNonce, localBytesPubkey,
|
let order = getOrder(remoteBytesPubkey, localNonce, localBytesPubkey,
|
||||||
remoteNonce).tryGet()
|
remoteNonce).tryGet()
|
||||||
trace "Remote proposal", schemes = remoteExchanges, ciphers = remoteCiphers,
|
trace "Remote proposal", schemes = remoteExchanges, ciphers = remoteCiphers,
|
||||||
|
|
|
@ -79,13 +79,15 @@ method getWrapped*(s: SecureConn): Connection = s.stream
|
||||||
|
|
||||||
method handshake*(s: Secure,
|
method handshake*(s: Secure,
|
||||||
conn: Connection,
|
conn: Connection,
|
||||||
initiator: bool): Future[SecureConn] {.async, base.} =
|
initiator: bool,
|
||||||
|
peerId: Opt[PeerId]): Future[SecureConn] {.async, base.} =
|
||||||
doAssert(false, "Not implemented!")
|
doAssert(false, "Not implemented!")
|
||||||
|
|
||||||
proc handleConn(s: Secure,
|
proc handleConn(s: Secure,
|
||||||
conn: Connection,
|
conn: Connection,
|
||||||
initiator: bool): Future[Connection] {.async.} =
|
initiator: bool,
|
||||||
var sconn = await s.handshake(conn, initiator)
|
peerId: Opt[PeerId]): Future[Connection] {.async.} =
|
||||||
|
var sconn = await s.handshake(conn, initiator, peerId)
|
||||||
# mark connection bottom level transport direction
|
# mark connection bottom level transport direction
|
||||||
# this is the safest place to do this
|
# this is the safest place to do this
|
||||||
# we require this information in for example gossipsub
|
# we require this information in for example gossipsub
|
||||||
|
@ -121,7 +123,7 @@ method init*(s: Secure) =
|
||||||
try:
|
try:
|
||||||
# We don't need the result but we
|
# We don't need the result but we
|
||||||
# definitely need to await the handshake
|
# definitely need to await the handshake
|
||||||
discard await s.handleConn(conn, false)
|
discard await s.handleConn(conn, false, Opt.none(PeerId))
|
||||||
trace "connection secured", conn
|
trace "connection secured", conn
|
||||||
except CancelledError as exc:
|
except CancelledError as exc:
|
||||||
warn "securing connection canceled", conn
|
warn "securing connection canceled", conn
|
||||||
|
@ -135,9 +137,10 @@ method init*(s: Secure) =
|
||||||
|
|
||||||
method secure*(s: Secure,
|
method secure*(s: Secure,
|
||||||
conn: Connection,
|
conn: Connection,
|
||||||
initiator: bool):
|
initiator: bool,
|
||||||
|
peerId: Opt[PeerId]):
|
||||||
Future[Connection] {.base.} =
|
Future[Connection] {.base.} =
|
||||||
s.handleConn(conn, initiator)
|
s.handleConn(conn, initiator, peerId)
|
||||||
|
|
||||||
method readOnce*(s: SecureConn,
|
method readOnce*(s: SecureConn,
|
||||||
pbytes: pointer,
|
pbytes: pointer,
|
||||||
|
|
|
@ -128,6 +128,13 @@ method connect*(
|
||||||
|
|
||||||
s.dialer.connect(peerId, addrs, forceDial)
|
s.dialer.connect(peerId, addrs, forceDial)
|
||||||
|
|
||||||
|
method connect*(
|
||||||
|
s: Switch,
|
||||||
|
addrs: seq[MultiAddress]): Future[PeerId] =
|
||||||
|
## Connects to a peer and retrieve its PeerId
|
||||||
|
|
||||||
|
s.dialer.connect(addrs)
|
||||||
|
|
||||||
method dial*(
|
method dial*(
|
||||||
s: Switch,
|
s: Switch,
|
||||||
peerId: PeerId,
|
peerId: PeerId,
|
||||||
|
|
|
@ -87,12 +87,13 @@ method upgradeIncoming*(
|
||||||
|
|
||||||
method upgradeOutgoing*(
|
method upgradeOutgoing*(
|
||||||
self: Transport,
|
self: Transport,
|
||||||
conn: Connection): Future[Connection] {.base, gcsafe.} =
|
conn: Connection,
|
||||||
|
peerId: Opt[PeerId]): Future[Connection] {.base, gcsafe.} =
|
||||||
## base upgrade method that the transport uses to perform
|
## base upgrade method that the transport uses to perform
|
||||||
## transport specific upgrades
|
## transport specific upgrades
|
||||||
##
|
##
|
||||||
|
|
||||||
self.upgrader.upgradeOutgoing(conn)
|
self.upgrader.upgradeOutgoing(conn, peerId)
|
||||||
|
|
||||||
method handles*(
|
method handles*(
|
||||||
self: Transport,
|
self: Transport,
|
||||||
|
|
|
@ -88,10 +88,11 @@ proc mux*(
|
||||||
|
|
||||||
method upgradeOutgoing*(
|
method upgradeOutgoing*(
|
||||||
self: MuxedUpgrade,
|
self: MuxedUpgrade,
|
||||||
conn: Connection): Future[Connection] {.async, gcsafe.} =
|
conn: Connection,
|
||||||
|
peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} =
|
||||||
trace "Upgrading outgoing connection", conn
|
trace "Upgrading outgoing connection", conn
|
||||||
|
|
||||||
let sconn = await self.secure(conn) # secure the connection
|
let sconn = await self.secure(conn, peerId) # secure the connection
|
||||||
if isNil(sconn):
|
if isNil(sconn):
|
||||||
raise newException(UpgradeFailedError,
|
raise newException(UpgradeFailedError,
|
||||||
"unable to secure connection, stopping upgrade")
|
"unable to secure connection, stopping upgrade")
|
||||||
|
@ -129,7 +130,7 @@ method upgradeIncoming*(
|
||||||
|
|
||||||
var cconn = conn
|
var cconn = conn
|
||||||
try:
|
try:
|
||||||
var sconn = await secure.secure(cconn, false)
|
var sconn = await secure.secure(cconn, false, Opt.none(PeerId))
|
||||||
if isNil(sconn):
|
if isNil(sconn):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -47,12 +47,14 @@ method upgradeIncoming*(
|
||||||
|
|
||||||
method upgradeOutgoing*(
|
method upgradeOutgoing*(
|
||||||
self: Upgrade,
|
self: Upgrade,
|
||||||
conn: Connection): Future[Connection] {.base.} =
|
conn: Connection,
|
||||||
|
peerId: Opt[PeerId]): Future[Connection] {.base.} =
|
||||||
doAssert(false, "Not implemented!")
|
doAssert(false, "Not implemented!")
|
||||||
|
|
||||||
proc secure*(
|
proc secure*(
|
||||||
self: Upgrade,
|
self: Upgrade,
|
||||||
conn: Connection): Future[Connection] {.async, gcsafe.} =
|
conn: Connection,
|
||||||
|
peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} =
|
||||||
if self.secureManagers.len <= 0:
|
if self.secureManagers.len <= 0:
|
||||||
raise newException(UpgradeFailedError, "No secure managers registered!")
|
raise newException(UpgradeFailedError, "No secure managers registered!")
|
||||||
|
|
||||||
|
@ -67,7 +69,7 @@ proc secure*(
|
||||||
# let's avoid duplicating checks but detect if it fails to do it properly
|
# let's avoid duplicating checks but detect if it fails to do it properly
|
||||||
doAssert(secureProtocol.len > 0)
|
doAssert(secureProtocol.len > 0)
|
||||||
|
|
||||||
return await secureProtocol[0].secure(conn, true)
|
return await secureProtocol[0].secure(conn, true, peerId)
|
||||||
|
|
||||||
proc identify*(
|
proc identify*(
|
||||||
self: Upgrade,
|
self: Upgrade,
|
||||||
|
|
|
@ -104,7 +104,7 @@ suite "Noise":
|
||||||
|
|
||||||
proc acceptHandler() {.async.} =
|
proc acceptHandler() {.async.} =
|
||||||
let conn = await transport1.accept()
|
let conn = await transport1.accept()
|
||||||
let sconn = await serverNoise.secure(conn, false)
|
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
|
||||||
try:
|
try:
|
||||||
await sconn.write("Hello!")
|
await sconn.write("Hello!")
|
||||||
finally:
|
finally:
|
||||||
|
@ -119,8 +119,7 @@ suite "Noise":
|
||||||
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
|
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
|
||||||
conn = await transport2.dial(transport1.addrs[0])
|
conn = await transport2.dial(transport1.addrs[0])
|
||||||
|
|
||||||
conn.peerId = serverInfo.peerId
|
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
|
||||||
let sconn = await clientNoise.secure(conn, true)
|
|
||||||
|
|
||||||
var msg = newSeq[byte](6)
|
var msg = newSeq[byte](6)
|
||||||
await sconn.readExactly(addr msg[0], 6)
|
await sconn.readExactly(addr msg[0], 6)
|
||||||
|
@ -149,7 +148,7 @@ suite "Noise":
|
||||||
var conn: Connection
|
var conn: Connection
|
||||||
try:
|
try:
|
||||||
conn = await transport1.accept()
|
conn = await transport1.accept()
|
||||||
discard await serverNoise.secure(conn, false)
|
discard await serverNoise.secure(conn, false, Opt.none(PeerId))
|
||||||
except CatchableError:
|
except CatchableError:
|
||||||
discard
|
discard
|
||||||
finally:
|
finally:
|
||||||
|
@ -162,11 +161,10 @@ suite "Noise":
|
||||||
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
|
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
|
||||||
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true, commonPrologue = @[1'u8, 2'u8, 3'u8])
|
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true, commonPrologue = @[1'u8, 2'u8, 3'u8])
|
||||||
conn = await transport2.dial(transport1.addrs[0])
|
conn = await transport2.dial(transport1.addrs[0])
|
||||||
conn.peerId = serverInfo.peerId
|
|
||||||
|
|
||||||
var sconn: Connection = nil
|
var sconn: Connection = nil
|
||||||
expect(NoiseDecryptTagError):
|
expect(NoiseDecryptTagError):
|
||||||
sconn = await clientNoise.secure(conn, true)
|
sconn = await clientNoise.secure(conn, true, Opt.some(conn.peerId))
|
||||||
|
|
||||||
await conn.close()
|
await conn.close()
|
||||||
await handlerWait
|
await handlerWait
|
||||||
|
@ -186,7 +184,7 @@ suite "Noise":
|
||||||
|
|
||||||
proc acceptHandler() {.async, gcsafe.} =
|
proc acceptHandler() {.async, gcsafe.} =
|
||||||
let conn = await transport1.accept()
|
let conn = await transport1.accept()
|
||||||
let sconn = await serverNoise.secure(conn, false)
|
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
|
||||||
defer:
|
defer:
|
||||||
await sconn.close()
|
await sconn.close()
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
@ -202,8 +200,7 @@ suite "Noise":
|
||||||
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
|
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
|
||||||
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
|
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
|
||||||
conn = await transport2.dial(transport1.addrs[0])
|
conn = await transport2.dial(transport1.addrs[0])
|
||||||
conn.peerId = serverInfo.peerId
|
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
|
||||||
let sconn = await clientNoise.secure(conn, true)
|
|
||||||
|
|
||||||
await sconn.write("Hello!")
|
await sconn.write("Hello!")
|
||||||
await acceptFut
|
await acceptFut
|
||||||
|
@ -230,7 +227,7 @@ suite "Noise":
|
||||||
|
|
||||||
proc acceptHandler() {.async, gcsafe.} =
|
proc acceptHandler() {.async, gcsafe.} =
|
||||||
let conn = await transport1.accept()
|
let conn = await transport1.accept()
|
||||||
let sconn = await serverNoise.secure(conn, false)
|
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
|
||||||
defer:
|
defer:
|
||||||
await sconn.close()
|
await sconn.close()
|
||||||
let msg = await sconn.readLp(1024*1024)
|
let msg = await sconn.readLp(1024*1024)
|
||||||
|
@ -244,8 +241,7 @@ suite "Noise":
|
||||||
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
|
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
|
||||||
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
|
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
|
||||||
conn = await transport2.dial(transport1.addrs[0])
|
conn = await transport2.dial(transport1.addrs[0])
|
||||||
conn.peerId = serverInfo.peerId
|
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
|
||||||
let sconn = await clientNoise.secure(conn, true)
|
|
||||||
|
|
||||||
await sconn.writeLp(hugePayload)
|
await sconn.writeLp(hugePayload)
|
||||||
await readTask
|
await readTask
|
||||||
|
|
|
@ -201,6 +201,20 @@ suite "Switch":
|
||||||
check not switch1.isConnected(switch2.peerInfo.peerId)
|
check not switch1.isConnected(switch2.peerInfo.peerId)
|
||||||
check not switch2.isConnected(switch1.peerInfo.peerId)
|
check not switch2.isConnected(switch1.peerInfo.peerId)
|
||||||
|
|
||||||
|
asyncTest "e2e connect to peer with unkown PeerId":
|
||||||
|
let switch1 = newStandardSwitch(secureManagers = [SecureProtocol.Noise])
|
||||||
|
let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Noise])
|
||||||
|
await switch1.start()
|
||||||
|
await switch2.start()
|
||||||
|
|
||||||
|
check: (await switch2.connect(switch1.peerInfo.addrs)) == switch1.peerInfo.peerId
|
||||||
|
await switch2.disconnect(switch1.peerInfo.peerId)
|
||||||
|
|
||||||
|
await allFuturesThrowing(
|
||||||
|
switch1.stop(),
|
||||||
|
switch2.stop()
|
||||||
|
)
|
||||||
|
|
||||||
asyncTest "e2e should not leak on peer disconnect":
|
asyncTest "e2e should not leak on peer disconnect":
|
||||||
let switch1 = newStandardSwitch()
|
let switch1 = newStandardSwitch()
|
||||||
let switch2 = newStandardSwitch()
|
let switch2 = newStandardSwitch()
|
||||||
|
|
Loading…
Reference in New Issue