diff --git a/asyncdispatch2.nimble b/asyncdispatch2.nimble index a6684522..9945022a 100644 --- a/asyncdispatch2.nimble +++ b/asyncdispatch2.nimble @@ -1,5 +1,5 @@ packageName = "asyncdispatch2" -version = "2.1.3" +version = "2.1.4" author = "Status Research & Development GmbH" description = "Asyncdispatch2" license = "Apache License 2.0 or MIT" diff --git a/asyncdispatch2/transports/common.nim b/asyncdispatch2/transports/common.nim index 794716e6..cec63951 100644 --- a/asyncdispatch2/transports/common.nim +++ b/asyncdispatch2/transports/common.nim @@ -86,11 +86,14 @@ type ## Transport's specific exception TransportOsError* = object of TransportError ## Transport's OS specific exception + code*: OSErrorCode TransportIncompleteError* = object of TransportError ## Transport's `incomplete data received` exception TransportLimitError* = object of TransportError ## Transport's `data limit reached` exception TransportAddressError* = object of TransportError + ## Transport's address specific exception + code*: OSErrorCode TransportState* = enum ## Transport's state @@ -290,7 +293,18 @@ template getError*(t: untyped): ref Exception = proc raiseTransportOsError*(err: OSErrorCode) = ## Raises transport specific OS error. var msg = "(" & $int(err) & ") " & osErrorMsg(err) - raise newException(TransportOsError, msg) + var tre = newException(TransportOsError, msg) + tre.code = err + raise tre + +template getTransportOsError*(err: OSErrorCode): ref TransportOsError = + var msg = "(" & $int(err) & ") " & osErrorMsg(err) + var tre = newException(TransportOsError, msg) + tre.code = err + tre + +template getTransportOsError*(err: cint): ref TransportOsError = + getTransportOsError(OSErrorCode(err)) type SeqHeader = object @@ -305,7 +319,10 @@ proc isLiteral*[T](s: seq[T]): bool {.inline.} = when defined(windows): import winlean - const ERROR_OPERATION_ABORTED* = 995 - const ERROR_SUCCESS* = 0 + const + ERROR_OPERATION_ABORTED* = 995 + ERROR_SUCCESS* = 0 + ERROR_CONNECTION_REFUSED* = 1225 + proc cancelIo*(hFile: HANDLE): WINBOOL {.stdcall, dynlib: "kernel32", importc: "CancelIo".} diff --git a/asyncdispatch2/transports/datagram.nim b/asyncdispatch2/transports/datagram.nim index 3f6dcf17..47451da7 100644 --- a/asyncdispatch2/transports/datagram.nim +++ b/asyncdispatch2/transports/datagram.nim @@ -56,7 +56,7 @@ type template setReadError(t, e: untyped) = (t).state.incl(ReadError) - (t).error = newException(TransportOsError, osErrorMsg((e))) + (t).error = getTransportOsError(e) template setWriterWSABuffer(t, v: untyped) = (t).wwsabuf.buf = cast[cstring](v.buf) @@ -85,7 +85,7 @@ when defined(windows): vector.writer.complete() else: transp.state = transp.state + {WritePaused, WriteError} - vector.writer.fail(newException(TransportOsError, osErrorMsg(err))) + vector.writer.fail(getTransportOsError(err)) else: ## Initiation transp.state.incl(WritePending) @@ -114,7 +114,7 @@ when defined(windows): else: transp.state.excl(WritePending) transp.state = transp.state + {WritePaused, WriteError} - vector.writer.fail(newException(TransportOsError, osErrorMsg(err))) + vector.writer.fail(getTransportOsError(err)) else: transp.queue.addFirst(vector) break @@ -297,18 +297,6 @@ when defined(windows): else: result.state.incl(ReadPaused) - # proc close*(transp: DatagramTransport) = - # ## Closes and frees resources of transport ``transp``. - # if ReadClosed notin transp.state and WriteClosed notin transp.state: - # # discard cancelIo(Handle(transp.fd)) - # closeSocket(transp.fd) - # transp.state.incl(WriteClosed) - # transp.state.incl(ReadClosed) - # transp.future.complete() - # if not isNil(transp.udata) and GCUserData in transp.flags: - # GC_unref(cast[ref int](transp.udata)) - # GC_unref(transp) - else: # Linux/BSD/MacOS part @@ -376,8 +364,7 @@ else: if int(err) == EINTR: continue else: - vector.writer.fail(newException(TransportOsError, - osErrorMsg(err))) + vector.writer.fail(getTransportOsError(err)) break else: transp.state.incl(WritePaused) diff --git a/asyncdispatch2/transports/stream.nim b/asyncdispatch2/transports/stream.nim index c48d169b..7f5af885 100644 --- a/asyncdispatch2/transports/stream.nim +++ b/asyncdispatch2/transports/stream.nim @@ -130,7 +130,7 @@ proc localAddress*(transp: StreamTransport): TransportAddress = template setReadError(t, e: untyped) = (t).state.incl(ReadError) - (t).error = newException(TransportOsError, osErrorMsg((e))) + (t).error = getTransportOsError(e) template checkPending(t: untyped) = if not isNil((t).reader): @@ -218,7 +218,7 @@ when defined(windows): else: let v = transp.queue.popFirst() transp.state.incl(WriteError) - v.writer.fail(newException(TransportOsError, osErrorMsg(err))) + v.writer.fail(getTransportOsError(err)) else: ## Initiation transp.state.incl(WritePending) @@ -243,8 +243,7 @@ when defined(windows): else: transp.state.excl(WritePending) transp.state = transp.state + {WritePaused, WriteError} - vector.writer.fail(newException(TransportOsError, - osErrorMsg(err))) + vector.writer.fail(getTransportOsError(err)) else: transp.queue.addFirst(vector) else: @@ -273,8 +272,7 @@ when defined(windows): else: transp.state.excl(WritePending) transp.state = transp.state + {WritePaused, WriteError} - vector.writer.fail(newException(TransportOsError, - osErrorMsg(err))) + vector.writer.fail(getTransportOsError(err)) else: transp.queue.addFirst(vector) break @@ -417,12 +415,16 @@ when defined(windows): toSockAddr(address.address, address.port, saddr, slen) sock = createAsyncSocket(address.address.getDomain(), SockType.SOCK_STREAM, Protocol.IPPROTO_TCP) + if sock == asyncInvalidSocket: - result.fail(newException(TransportOsError, osErrorMsg(osLastError()))) + retFuture.fail(getTransportOsError(OSErrorCode(wsaGetLastError()))) + return retFuture if not bindToDomain(sock, address.address.getDomain()): + let err = wsaGetLastError() sock.closeSocket() - result.fail(newException(TransportOsError, osErrorMsg(osLastError()))) + retFuture.fail(getTransportOsError(err)) + return retFuture proc continuation(udata: pointer) = var ovl = cast[RefCustomOverlapped](udata) @@ -432,16 +434,14 @@ when defined(windows): cint(SO_UPDATE_CONNECT_CONTEXT), nil, SockLen(0)) != 0'i32: sock.closeSocket() - retFuture.fail(newException(TransportOsError, - osErrorMsg(osLastError()))) + retFuture.fail(getTransportOsError(wsaGetLastError())) else: retFuture.complete(newStreamSocketTransport(povl.data.fd, bufferSize, child)) else: sock.closeSocket() - retFuture.fail(newException(TransportOsError, - osErrorMsg(ovl.data.errCode))) + retFuture.fail(getTransportOsError(ovl.data.errCode)) GC_unref(ovl) povl = RefCustomOverlapped() @@ -457,7 +457,7 @@ when defined(windows): if int32(err) != ERROR_IO_PENDING: GC_unref(povl) sock.closeSocket() - retFuture.fail(newException(TransportOsError, osErrorMsg(err))) + retFuture.fail(getTransportOsError(err)) return retFuture proc acceptLoop(udata: pointer) {.gcsafe, nimcall.} = @@ -477,8 +477,9 @@ when defined(windows): cint(SO_UPDATE_ACCEPT_CONTEXT), addr server.sock, SockLen(sizeof(SocketHandle))) != 0'i32: + let err = OSErrorCode(wsaGetLastError()) server.asock.closeSocket() - raiseTransportOsError(osLastError()) + raiseTransportOsError(err) else: if not isNil(server.init): var transp = server.init(server, server.asock) @@ -495,8 +496,9 @@ when defined(windows): server.asock.closeSocket() break else: + let err = OSErrorCode(wsaGetLastError()) server.asock.closeSocket() - raiseTransportOsError(osLastError()) + raiseTransportOsError(err) else: ## Initiation if server.status in {ServerStatus.Stopped, ServerStatus.Closed}: @@ -507,7 +509,7 @@ when defined(windows): server.asock = createAsyncSocket(server.domain, SockType.SOCK_STREAM, Protocol.IPPROTO_TCP) if server.asock == asyncInvalidSocket: - raiseTransportOsError(osLastError()) + raiseTransportOsError(OSErrorCode(wsaGetLastError())) var dwBytesReceived = DWORD(0) let dwReceiveDataLength = DWORD(0) @@ -588,8 +590,7 @@ else: if int(err) == EINTR: continue else: - vector.writer.fail(newException(TransportOsError, - osErrorMsg(err))) + vector.writer.fail(getTransportOsError(err)) else: let res = sendfile(int(fd), cast[int](vector.buflen), int(vector.offset), @@ -605,8 +606,7 @@ else: if int(err) == EINTR: continue else: - vector.writer.fail(newException(TransportOsError, - osErrorMsg(err))) + vector.writer.fail(getTransportOsError(err)) break else: transp.state.incl(WritePaused) @@ -686,7 +686,7 @@ else: sock = createAsyncSocket(address.address.getDomain(), SockType.SOCK_STREAM, Protocol.IPPROTO_TCP) if sock == asyncInvalidSocket: - retFuture.fail(newException(TransportOsError, osErrorMsg(osLastError()))) + retFuture.fail(getTransportOsError(osLastError())) return retFuture proc continuation(udata: pointer) = @@ -696,13 +696,11 @@ else: fd.removeWriter() if not fd.getSocketError(err): closeSocket(fd) - retFuture.fail(newException(TransportOsError, - osErrorMsg(osLastError()))) + retFuture.fail(getTransportOsError(osLastError())) return if err != 0: closeSocket(fd) - retFuture.fail(newException(TransportOsError, - osErrorMsg(OSErrorCode(err)))) + retFuture.fail(getTransportOsError(OSErrorCode(err))) return retFuture.complete(newStreamSocketTransport(fd, bufferSize, child)) @@ -721,7 +719,7 @@ else: break else: sock.closeSocket() - retFuture.fail(newException(TransportOsError, osErrorMsg(err))) + retFuture.fail(getTransportOsError(err)) break return retFuture @@ -782,7 +780,7 @@ proc stop*(server: StreamServer) = proc join*(server: StreamServer): Future[void] = ## Waits until ``server`` is not closed. - var retFuture = newFuture[void]("streamserver.join") + var retFuture = newFuture[void]("stream.server.join") proc continuation(udata: pointer) = retFuture.complete() if not server.loopFuture.finished: server.loopFuture.addCallback(continuation) diff --git a/tests/teststream.nim b/tests/teststream.nim index 8155ec3f..7401fbf0 100644 --- a/tests/teststream.nim +++ b/tests/teststream.nim @@ -632,6 +632,15 @@ proc test14(): Future[int] {.async.} = await server.join() result = subres +proc testConnectionRefused(): Future[bool] {.async.} = + try: + var transp = await connect(initTAddress("127.0.0.1:1")) + except TransportOsError as e: + when defined(windows): + result = (int(e.code) == ERROR_CONNECTION_REFUSED) + else: + result = (int(e.code) == ECONNREFUSED) + when isMainModule: const m1 = "readLine() multiple clients with messages (" & $ClientsCount & @@ -653,6 +662,7 @@ when isMainModule: m12 = "readUntil() unexpected disconnect test" m13 = "readLine() unexpected disconnect empty string test" m14 = "Closing socket while operation pending test (issue #8)" + m15 = "Connection refused test" suite "Stream Transport test suite": test m8: check waitFor(test8()) == 1 @@ -682,3 +692,5 @@ when isMainModule: check waitFor(test6()) == ClientsCount * MessagesCount test m4: check waitFor(test4()) == FilesCount + test m15: + check waitFor(testConnectionRefused()) == true