diff --git a/chronos/streams/tlsstream.nim b/chronos/streams/tlsstream.nim index 7010c3a4..8230e5a5 100644 --- a/chronos/streams/tlsstream.nim +++ b/chronos/streams/tlsstream.nim @@ -35,7 +35,7 @@ type RSA, EC TLSResult {.pure.} = enum - Success, Error, EOF + Success, Error, Stopped, WriteEof, ReadEof TLSPrivateKey* = ref object case kind: TLSKeyType @@ -126,6 +126,9 @@ template newTLSStreamProtocolImpl[T](message: T): ref TLSStreamProtocolError = err.errCode = code err +template newTLSUnexpectedProtocolError(): ref TLSStreamProtocolError = + newException(TLSStreamProtocolError, "Unexpected internal error") + proc newTLSStreamProtocolError[T](message: T): ref TLSStreamProtocolError = newTLSStreamProtocolImpl(message) @@ -142,12 +145,14 @@ proc tlsWriteRec(engine: ptr SslEngineContext, sslEngineSendrecAck(engine[], length) return TLSResult.Success except AsyncStreamError as exc: - if writer.state == AsyncStreamState.Running: - writer.state = AsyncStreamState.Error - writer.error = exc + writer.state = AsyncStreamState.Error + writer.error = exc + return TLSResult.Error except CancelledError: if writer.state == AsyncStreamState.Running: writer.state = AsyncStreamState.Stopped + return TLSResult.Stopped + return TLSResult.Error proc tlsWriteApp(engine: ptr SslEngineContext, @@ -157,6 +162,13 @@ proc tlsWriteApp(engine: ptr SslEngineContext, if item.size > 0: var length = 0'u var buf = sslEngineSendappBuf(engine[], length) + if isNil(buf) or (length == 0): + # This situation could happen when connection is closing, no + # application data can be sent, but some can still be received + # (and discarded). + writer.state = AsyncStreamState.Finished + return TLSResult.WriteEof + let toWrite = min(int(length), item.size) copyOut(buf, item, toWrite) if int(length) >= item.size: @@ -180,6 +192,8 @@ proc tlsWriteApp(engine: ptr SslEngineContext, except CancelledError: if writer.state == AsyncStreamState.Running: writer.state = AsyncStreamState.Stopped + return TLSResult.Stopped + return TLSResult.Error proc tlsReadRec(engine: ptr SslEngineContext, @@ -191,17 +205,18 @@ proc tlsReadRec(engine: ptr SslEngineContext, sslEngineRecvrecAck(engine[], uint(res)) if res == 0: sslEngineClose(engine[]) - - return TLSResult.EOF + return TLSResult.ReadEof else: return TLSResult.Success + except AsyncStreamError as exc: + reader.state = AsyncStreamState.Error + reader.error = exc + return TLSResult.Error except CancelledError: if reader.state == AsyncStreamState.Running: reader.state = AsyncStreamState.Stopped - except AsyncStreamError as exc: - if reader.state == AsyncStreamState.Running: - reader.state = AsyncStreamState.Error - reader.error = exc + return TLSResult.Stopped + return TLSResult.Error proc tlsReadApp(engine: ptr SslEngineContext, @@ -215,13 +230,15 @@ proc tlsReadApp(engine: ptr SslEngineContext, except CancelledError: if reader.state == AsyncStreamState.Running: reader.state = AsyncStreamState.Stopped + return TLSResult.Stopped + return TLSResult.Error template readAndReset(fut: untyped) = if fut.finished(): let res = fut.read() case res - of TLSResult.Success: + of TLSResult.Success, TLSResult.WriteEof, TLSResult.Stopped: fut = nil continue of TLSResult.Error: @@ -229,7 +246,7 @@ template readAndReset(fut: untyped) = if loopState == AsyncStreamState.Running: loopState = AsyncStreamState.Error break - of TLSResult.EOF: + of TLSResult.ReadEof: fut = nil if loopState == AsyncStreamState.Running: loopState = AsyncStreamState.Finished @@ -301,14 +318,14 @@ proc tlsLoop*(stream: TLSAsyncStream) {.async.} = if isNil(sendAppFut): if (state and SSL_SENDAPP) == SSL_SENDAPP: - # Application data can be sent over stream. - if not(stream.writer.handshaked): - stream.reader.handshaked = true - stream.writer.handshaked = true - if not(isNil(stream.writer.handshakeFut)): - stream.writer.handshakeFut.complete() - - sendAppFut = tlsWriteApp(engine, stream.writer) + if stream.writer.state == AsyncStreamState.Running: + # Application data can be sent over stream. + if not(stream.writer.handshaked): + stream.reader.handshaked = true + stream.writer.handshaked = true + if not(isNil(stream.writer.handshakeFut)): + stream.writer.handshakeFut.complete() + sendAppFut = tlsWriteApp(engine, stream.writer) else: sendAppFut.readAndReset() @@ -353,8 +370,10 @@ proc tlsLoop*(stream: TLSAsyncStream) {.async.} = of AsyncStreamState.Error: if not(isNil(stream.writer.error)): stream.writer.error - else: + elif not(isNil(stream.reader.error)): newTLSStreamWriteError(stream.reader.error) + else: + newTLSUnexpectedProtocolError() of AsyncStreamState.Finished: let err = engine[].sslEngineLastError() if err != 0: @@ -389,7 +408,8 @@ proc tlsLoop*(stream: TLSAsyncStream) {.async.} = if not(isNil(stream.writer.handshakeFut)): if not(stream.writer.handshakeFut.finished()): stream.writer.handshakeFut.fail( - newTLSStreamProtocolError("Connection with remote peer lost") + newTLSStreamProtocolError( + "Connection to the remote peer has been lost") ) # Completing readers