From a0c724e9d8e0cce3dc06ed62a7837ac318371258 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Fri, 15 Jun 2018 03:28:02 +0300 Subject: [PATCH] Fix Windows/Unix behavior on datagram CONNRESET. Add test for CONNRESET --- asyncdispatch2/asyncloop.nim | 19 ++++++++------ asyncdispatch2/transports/datagram.nim | 34 +++++++++++++++++--------- asyncdispatch2/transports/stream.nim | 4 +-- tests/testdatagram.nim | 19 ++++++++++++++ 4 files changed, 55 insertions(+), 21 deletions(-) diff --git a/asyncdispatch2/asyncloop.nim b/asyncdispatch2/asyncloop.nim index c805220..d211209 100644 --- a/asyncdispatch2/asyncloop.nim +++ b/asyncdispatch2/asyncloop.nim @@ -227,7 +227,6 @@ template processCallbacks(loop: untyped) = when defined(windows) or defined(nimdoc): import winlean, sets, hashes - type WSAPROC_TRANSMITFILE = proc(hSocket: SocketHandle, hFile: Handle, nNumberOfBytesToWrite: DWORD, @@ -341,10 +340,10 @@ when defined(windows) or defined(nimdoc): loop.processCallbacks() proc getFunc(s: SocketHandle, fun: var pointer, guid: var GUID): bool = - var bytesRet: Dword + var bytesRet: DWORD fun = nil result = WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, addr guid, - sizeof(GUID).Dword, addr fun, sizeof(pointer).Dword, + sizeof(GUID).DWORD, addr fun, sizeof(pointer).DWORD, addr bytesRet, nil, nil) == 0 proc initAPI() = @@ -360,26 +359,30 @@ when defined(windows) or defined(nimdoc): if wsaStartup(0x0202'i16, addr wsa) != 0: raiseOSError(osLastError()) - let sock = winlean.socket(winlean.AF_INET, 1 , 6) + let sock = winlean.socket(winlean.AF_INET, 1, 6) if sock == INVALID_SOCKET: raiseOSError(osLastError()) var funcPointer: pointer = nil if not getFunc(sock, funcPointer, WSAID_CONNECTEX): + let err = osLastError() close(sock) - raiseOSError(osLastError()) + raiseOSError(err) loop.connectEx = cast[WSAPROC_CONNECTEX](funcPointer) if not getFunc(sock, funcPointer, WSAID_ACCEPTEX): + let err = osLastError() close(sock) - raiseOSError(osLastError()) + raiseOSError(err) loop.acceptEx = cast[WSAPROC_ACCEPTEX](funcPointer) if not getFunc(sock, funcPointer, WSAID_GETACCEPTEXSOCKADDRS): + let err = osLastError() close(sock) - raiseOSError(osLastError()) + raiseOSError(err) loop.getAcceptExSockAddrs = cast[WSAPROC_GETACCEPTEXSOCKADDRS](funcPointer) if not getFunc(sock, funcPointer, WSAID_TRANSMITFILE): + let err = osLastError() close(sock) - raiseOSError(osLastError()) + raiseOSError(err) loop.transmitFile = cast[WSAPROC_TRANSMITFILE](funcPointer) close(sock) diff --git a/asyncdispatch2/transports/datagram.nim b/asyncdispatch2/transports/datagram.nim index d6b58c8..18749d3 100644 --- a/asyncdispatch2/transports/datagram.nim +++ b/asyncdispatch2/transports/datagram.nim @@ -63,6 +63,9 @@ template setWriterWSABuffer(t, v: untyped) = (t).wwsabuf.len = cast[int32](v.buflen) when defined(windows): + const + IOC_VENDOR = DWORD(0x18000000) + SIO_UDP_CONNRESET = DWORD(winlean.IOC_IN) or IOC_VENDOR or DWORD(12) proc writeDatagramLoop(udata: pointer) = var bytesCount: int32 @@ -213,10 +216,10 @@ when defined(windows): localSock = createAsyncSocket(Domain.AF_INET6, SockType.SOCK_DGRAM, Protocol.IPPROTO_UDP) if localSock == asyncInvalidSocket: - raiseOsError(osLastError()) + raiseTransportOsError(osLastError()) else: if not setSocketBlocking(SocketHandle(sock), false): - raiseOsError(osLastError()) + raiseTransportOsError(osLastError()) localSock = sock register(localSock) @@ -226,7 +229,15 @@ when defined(windows): let err = osLastError() if sock == asyncInvalidSocket: closeAsyncSocket(localSock) - raiseOsError(err) + raiseTransportOsError(err) + + ## Fix for Q263823. + var bytesRet: DWORD + var bval = WINBOOL(0) + if WSAIoctl(SocketHandle(localSock), SIO_UDP_CONNRESET, addr bval, + sizeof(WINBOOL).DWORD, nil, DWORD(0), + addr bytesRet, nil, nil) != 0: + raiseTransportOsError(osLastError()) if local.port != Port(0): var saddr: Sockaddr_storage @@ -237,7 +248,7 @@ when defined(windows): let err = osLastError() if sock == asyncInvalidSocket: closeAsyncSocket(localSock) - raiseOsError(err) + raiseTransportOsError(err) result.local = local else: var saddr: Sockaddr_storage @@ -253,7 +264,7 @@ when defined(windows): let err = osLastError() if sock == asyncInvalidSocket: closeAsyncSocket(localSock) - raiseOsError(err) + raiseTransportOsError(err) if remote.port != Port(0): var saddr: Sockaddr_storage @@ -264,7 +275,7 @@ when defined(windows): let err = osLastError() if sock == asyncInvalidSocket: closeAsyncSocket(localSock) - raiseOsError(err) + raiseTransportOsError(err) result.remote = remote result.fd = localSock @@ -299,6 +310,7 @@ when defined(windows): GC_unref(transp) else: + # Linux/BSD/MacOS part proc readDatagramLoop(udata: pointer) = var @@ -403,10 +415,10 @@ else: localSock = createAsyncSocket(Domain.AF_INET6, SockType.SOCK_DGRAM, Protocol.IPPROTO_UDP) if localSock == asyncInvalidSocket: - raiseOsError(osLastError()) + raiseTransportOsError(osLastError()) else: if not setSocketBlocking(SocketHandle(sock), false): - raiseOsError(osLastError()) + raiseTransportOsError(osLastError()) localSock = sock register(localSock) @@ -416,7 +428,7 @@ else: let err = osLastError() if sock == asyncInvalidSocket: closeAsyncSocket(localSock) - raiseOsError(err) + raiseTransportOsError(err) if local.port != Port(0): var saddr: Sockaddr_storage @@ -427,7 +439,7 @@ else: let err = osLastError() if sock == asyncInvalidSocket: closeAsyncSocket(localSock) - raiseOsError(err) + raiseTransportOsError(err) result.local = local if remote.port != Port(0): @@ -439,7 +451,7 @@ else: let err = osLastError() if sock == asyncInvalidSocket: closeAsyncSocket(localSock) - raiseOsError(err) + raiseTransportOsError(err) result.remote = remote result.fd = localSock diff --git a/asyncdispatch2/transports/stream.nim b/asyncdispatch2/transports/stream.nim index f54e6fe..6bb371a 100644 --- a/asyncdispatch2/transports/stream.nim +++ b/asyncdispatch2/transports/stream.nim @@ -111,7 +111,7 @@ proc remoteAddress*(transp: StreamTransport): TransportAddress = var slen = SockLen(sizeof(saddr)) if getpeername(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr), addr slen) != 0: - raiseOsError(osLastError()) + raiseTransportOsError(osLastError()) fromSockAddr(saddr, slen, transp.remote.address, transp.remote.port) result = transp.remote @@ -124,7 +124,7 @@ proc localAddress*(transp: StreamTransport): TransportAddress = var slen = SockLen(sizeof(saddr)) if getsockname(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr), addr slen) != 0: - raiseOsError(osLastError()) + raiseTransportOsError(osLastError()) fromSockAddr(saddr, slen, transp.local.address, transp.local.port) result = transp.local diff --git a/tests/testdatagram.nim b/tests/testdatagram.nim index 8bc0fee..7262088 100644 --- a/tests/testdatagram.nim +++ b/tests/testdatagram.nim @@ -429,6 +429,23 @@ proc test3(bounded: bool): Future[int] {.async.} = for i in 0..