diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index 6277a9de0..b6f5ee7c0 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -65,6 +65,7 @@ method handle*(m: Mplex) {.async, gcsafe.} = trace "waiting for data" let msg = await m.connection.readMsg() if msg.isNone: + trace "connection EOF" # TODO: allow poll with timeout to avoid using `sleepAsync` await sleepAsync(1.millis) continue diff --git a/libp2p/protocols/secure/secio.nim b/libp2p/protocols/secure/secio.nim index c0a73185b..a1f0a332f 100644 --- a/libp2p/protocols/secure/secio.nim +++ b/libp2p/protocols/secure/secio.nim @@ -68,6 +68,8 @@ type writerCoder: SecureCipher readerCoder: SecureCipher + SecioError* = object of CatchableError + proc init(mac: var SecureMac, hash: string, key: openarray[byte]) = if hash == "SHA256": mac = SecureMac(kind: SecureMacType.Sha256) @@ -313,16 +315,16 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.} if len(answer) == 0: trace "Proposal exchange failed", conn = conn - return + raise newException(SecioError, "Proposal exchange failed") if not decodeProposal(answer, remoteNonce, remoteBytesPubkey, remoteExchanges, remoteCiphers, remoteHashes): trace "Remote proposal decoding failed", conn = conn - return + raise newException(SecioError, "Remote proposal decoding failed") if not remotePubkey.init(remoteBytesPubkey): trace "Remote public key incorrect or corrupted", pubkey = remoteBytesPubkey - return + raise newException(SecioError, "Remote public key incorrect or corrupted") remotePeerId = PeerID.init(remotePubkey) @@ -340,7 +342,7 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.} let hash = selectBest(order, SecioHashes, remoteHashes) if len(scheme) == 0 or len(cipher) == 0 or len(hash) == 0: trace "No algorithms in common", peer = remotePeerId - return + raise newException(SecioError, "No algorithms in common") trace "Encryption scheme selected", scheme = scheme, cipher = cipher, hash = hash @@ -352,41 +354,40 @@ proc handshake*(s: Secio, conn: Connection): Future[SecureConnection] {.async.} var signature = s.localPrivateKey.sign(localCorpus) var localExchange = createExchange(epubkey, signature.getBytes()) - var remoteExchange = await transactMessage(conn, localExchange) if len(remoteExchange) == 0: trace "Corpus exchange failed", conn = conn - return + raise newException(SecioError, "Corpus exchange failed") if not decodeExchange(remoteExchange, remoteEBytesPubkey, remoteEBytesSig): trace "Remote exchange decoding failed", conn = conn - return + raise newException(SecioError, "Remote exchange decoding failed") if not remoteESignature.init(remoteEBytesSig): trace "Remote signature incorrect or corrupted", signature = toHex(remoteEBytesSig) - return + raise newException(SecioError, "Remote signature incorrect or corrupted") var remoteCorpus = answer & request[4..^1] & remoteEBytesPubkey if not remoteESignature.verify(remoteCorpus, remotePubkey): trace "Signature verification failed", scheme = remotePubkey.scheme, signature = remoteESignature, pubkey = remotePubkey, corpus = remoteCorpus - return + raise newException(SecioError, "Signature verification failed") trace "Signature verified", scheme = remotePubkey.scheme if not remoteEPubkey.eckey.initRaw(remoteEBytesPubkey): trace "Remote ephemeral public key incorrect or corrupted", pubkey = toHex(remoteEBytesPubkey) - return + raise newException(SecioError, "Remote ephemeral public key incorrect or corrupted") var secret = getSecret(remoteEPubkey, ekeypair.seckey) if len(secret) == 0: trace "Shared secret could not be created", pubkeyScheme = remoteEPubkey.scheme, seckeyScheme = ekeypair.seckey.scheme - return + raise newException(SecioError, "Shared secret could not be created") trace "Shared secret calculated", secret = toHex(secret) @@ -416,15 +417,16 @@ proc readLoop(sconn: SecureConnection, stream: BufferStream) {.async.} = try: while not sconn.closed: let msg = await sconn.readMessage() - if msg.len > 0: - await stream.pushTo(msg) + if msg.len == 0: + trace "stream EOF" + return - # tight loop, give a chance for other - # stuff to run as well - await sleepAsync(1.millis) + await stream.pushTo(msg) except CatchableError as exc: trace "exception occured", exc = exc.msg finally: + if not sconn.closed: + await sconn.close() trace "ending secio readLoop", isclosed = sconn.closed() proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async, gcsafe.} = @@ -436,26 +438,34 @@ proc handleConn(s: Secio, conn: Connection): Future[Connection] {.async, gcsafe. var stream = newBufferStream(writeHandler) asyncCheck readLoop(sconn, stream) var secured = newConnection(stream) + secured.peerInfo = PeerInfo.init(sconn.peerInfo.publicKey.get()) + result = secured + secured.closeEvent.wait() .addCallback do (udata: pointer): trace "wrapped connection closed, closing upstream" if not isNil(sconn) and not sconn.closed: asyncCheck sconn.close() - secured.peerInfo = PeerInfo.init(sconn.peerInfo.publicKey.get()) - result = secured - method init(s: Secio) {.gcsafe.} = proc handle(conn: Connection, proto: string) {.async, gcsafe.} = trace "handling connection" - discard await s.handleConn(conn) - trace "connection secured" + try: + discard await s.handleConn(conn) + trace "connection secured" + except CatchableError as exc: + trace "securing connection failed", msg = exc.msg + await conn.close() s.codec = SecioCodec s.handler = handle -method secure*(s: Secio, conn: Connection): Future[Connection] {.gcsafe.} = - result = s.handleConn(conn) +method secure*(s: Secio, conn: Connection): Future[Connection] {.async, gcsafe.} = + try: + result = await s.handleConn(conn) + except CatchableError as exc: + trace "securing connection failed", msg = exc.msg + await conn.close() proc newSecio*(localPrivateKey: PrivateKey): Secio = new result diff --git a/libp2p/switch.nim b/libp2p/switch.nim index cfbebd6f3..358d067c6 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -140,7 +140,8 @@ proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} = s.muxed.del(id) if id in s.connections: - await s.connections[id].close() + if not s.connections[id].closed: + await s.connections[id].close() s.connections.del(id) proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} = @@ -166,6 +167,9 @@ proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, g return result = await s.secure(result) # secure the connection + if isNil(result): + return + await s.mux(result) # mux it if possible s.connections[conn.peerInfo.id] = result @@ -212,8 +216,12 @@ proc dial*(s: Switch, # make sure to assign the peer to the connection result.peerInfo = peer result = await s.upgradeOutgoing(result) - result.closeEvent.wait().addCallback do (udata: pointer): - asyncCheck s.cleanupConn(result) + if isNil(result): + continue + + result.closeEvent.wait() + .addCallback do (udata: pointer): + asyncCheck s.cleanupConn(result) break else: trace "Reusing existing connection" @@ -252,7 +260,9 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = except CatchableError as exc: trace "exception occured", exc = exc.msg finally: - await conn.close() + if not isNil(conn) and not conn.closed: + await conn.close() + await s.cleanupConn(conn) var startFuts: seq[Future[void]] @@ -346,4 +356,3 @@ proc newSwitch*(peerInfo: PeerInfo, if pubSub.isSome: result.pubSub = pubSub result.mount(pubSub.get()) -